Commit 09d57aa

[Executorch][llm] Make mask tensor float only for sdpa
Now that we support quantized SDPA, the query tensor can be quantized while the attention mask remains float (the only allowed type), so this check no longer makes sense.

Differential Revision: [D77516821](https://our.internmc.facebook.com/intern/diff/D77516821/)
ghstack-source-id: 293661338
Pull Request resolved: #12131
Parent: 7b9ab92

1 file changed: +2 -2

extension/llm/custom_ops/op_sdpa.cpp

@@ -59,8 +59,8 @@ bool validate_flash_attention_args(
 
   ET_CHECK_OR_RETURN_FALSE(
       !attn_mask.has_value() ||
-          attn_mask.value().scalar_type() == query.scalar_type(),
-      "Attention mask must be a 2D tensor");
+          attn_mask.value().scalar_type() == ScalarType::Float,
+      "Attention mask must be a Float tensor");
 
   ET_CHECK_OR_RETURN_FALSE(
       is_contiguous_dim_order(query.dim_order().data(), query.dim()),
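
For context, here is a minimal standalone sketch of the reasoning behind the change. It is not the ExecuTorch API; `Tensor` and `ScalarType` below are simplified stand-ins. It illustrates why the old dtype-equality check rejects the now-valid combination of a quantized (int8) query with a float attention mask, while the new float-only check accepts it:

```cpp
#include <cassert>

// Hypothetical stand-ins for the ExecuTorch tensor types used in op_sdpa.cpp.
enum class ScalarType { Float, Char /* int8, used here for quantized tensors */ };

struct Tensor {
  ScalarType type;
  ScalarType scalar_type() const { return type; }
};

int main() {
  Tensor query{ScalarType::Char};      // quantized query tensor (int8)
  Tensor attn_mask{ScalarType::Float}; // the mask is always float

  // Old check: requires the mask dtype to match the query dtype, so it
  // rejects this valid quantized-query + float-mask combination.
  bool old_ok = attn_mask.scalar_type() == query.scalar_type();
  assert(!old_ok);

  // New check: only requires the mask itself to be float, independent of
  // the query dtype, so quantized SDPA inputs pass validation.
  bool new_ok = attn_mask.scalar_type() == ScalarType::Float;
  assert(new_ok);
  return 0;
}
```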
