Skip to content

Commit ce6ba0b

Browse files
committed
better error message
1 parent 37a861f commit ce6ba0b

File tree

1 file changed

+2
-1
lines changed
  • onnxscript/function_libs/torch_lib/ops

1 file changed

+2
-1
lines changed

onnxscript/function_libs/torch_lib/ops/nn.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1840,7 +1840,8 @@ def aten_scaled_dot_product_attention(
18401840
key, value = _attention_repeat_kv_for_group_query(query, key, value)
18411841
else:
18421842
assert query.shape[1] == key.shape[1] == value.shape[1], (
1843-
"SDPA (MHA) requires q_num_heads = kv_num_heads"
1843+
"SDPA (MHA) requires q_num_heads = kv_num_heads, "
1844+
f"query.shape={query.shape}, key.shape{key.shape}, value.shape={value.shape}"
18441845
)
18451846

18461847
if attn_mask is None:

0 commit comments

Comments
 (0)