diff --git a/extension/llm/custom_ops/op_sdpa.cpp b/extension/llm/custom_ops/op_sdpa.cpp
index 91802a8445d..c98fa1729fa 100644
--- a/extension/llm/custom_ops/op_sdpa.cpp
+++ b/extension/llm/custom_ops/op_sdpa.cpp
@@ -59,8 +59,8 @@ bool validate_flash_attention_args(
   ET_CHECK_OR_RETURN_FALSE(
       !attn_mask.has_value() ||
-          attn_mask.value().scalar_type() == query.scalar_type(),
-      "Attention mask must be a 2D tensor");
+          attn_mask.value().scalar_type() == ScalarType::Float,
+      "Attention mask must be a Float tensor");
   ET_CHECK_OR_RETURN_FALSE(
       is_contiguous_dim_order(query.dim_order().data(), query.dim()),
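
For context, this hunk decouples the mask dtype from the query dtype (and fixes the error message, which previously talked about dimensionality while the check was about dtype): with the change, an optional attention mask must be Float even when the query is, say, Half or BFloat16. Below is a minimal standalone C++ sketch of that rule only; `ScalarType` here is a stand-in enum and `validate_mask_dtype` is a hypothetical helper, not the ExecuTorch API.

```cpp
#include <iostream>
#include <optional>

// Stand-in for the real ScalarType enum; illustration only.
enum class ScalarType { Float, Half, BFloat16 };

// New rule from the patch: if a mask is provided, its dtype must be Float,
// independent of the query dtype.
bool validate_mask_dtype(
    ScalarType /*query_dtype*/,
    std::optional<ScalarType> mask_dtype) {
  return !mask_dtype.has_value() || mask_dtype.value() == ScalarType::Float;
}

int main() {
  // Half query + Float mask: rejected by the old same-dtype check,
  // accepted by the new Float-only check.
  std::cout << validate_mask_dtype(ScalarType::Half, ScalarType::Float) << "\n"; // 1
  // Half query + Half mask: accepted by the old check, rejected now.
  std::cout << validate_mask_dtype(ScalarType::Half, ScalarType::Half) << "\n";  // 0
  // No mask at all: always accepted.
  std::cout << validate_mask_dtype(ScalarType::Half, std::nullopt) << "\n";      // 1
  return 0;
}
```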