We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 7fbe280 commit c12fd97Copy full SHA for c12fd97
extension/llm/custom_ops/op_sdpa.cpp
@@ -400,7 +400,8 @@ Tensor& custom_sdpa_out_impl(
400
401
ET_CHECK_MSG(q.dim() == 4, "query must be a 4D tensor");
402
403
- const int64_t num_keys_for_causal_attention = start_pos + seq_len;
+ const int64_t num_keys_for_causal_attention =
404
+ attn_mask.has_value() ? -1 : start_pos + seq_len;
405
406
ET_KERNEL_CHECK(
407
ctx,
0 commit comments