Skip to content

Commit 21a1594

Browse files
committed
Handle swapped key format
Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>
1 parent e25ef84 commit 21a1594

File tree

2 files changed

+51
-12
lines changed

2 files changed

+51
-12
lines changed

onnxscript/rewriter/ort_fusions/sdpa_test.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -292,20 +292,41 @@ def _masked_custom_scale_post_mul_sdpa_script(query, key, value, mask):
292292
return attn_output
293293

294294

295+
# Scenario under test: the key input arrives in BSHd layout rather than BHSd.
# Such a graph results from an optimization that fuses two consecutive
# transposes into one: BSHd -> BHSd, followed by BHSd -> BHdS ahead of the
# MatMul. That is why the first Transpose below differs from the one used by
# the other test scripts.
@script()
def _unmasked_pre_div_sdpa_BSHd_key_script(query, key, value):
    # The same constant divides both MatMul operands.
    divisor = op.Constant(value_float=SQRT_SCALE_FACTOR)
    # Used below to overwrite NaN entries of the softmax output.
    zero_fill = op.Constant(value_float=0.0)

    key_bhds = op.Transpose(key, perm=[0, 2, 3, 1])  # BSHd to BHdS
    query_scaled = op.Div(query, divisor)
    key_scaled = op.Div(key_bhds, divisor)

    scores = op.MatMul(query_scaled, key_scaled)
    probs = op.Softmax(scores, axis=-1)
    nan_mask = op.IsNaN(probs)
    clean_probs = op.Where(nan_mask, zero_fill, probs)

    attn_output = op.MatMul(clean_probs, value)
    return attn_output
312+
313+
295314
class SDPATestCase:
296-
def __init__(self, script_func, *, with_mask, BSHd_key=False):
    """Store the script under test and the flags describing its inputs.

    Args:
        script_func: Script function; ``get_onnx_model`` calls its
            ``to_model_proto`` method.
        with_mask: When true, a mask input type is appended to the model
            inputs.
        BSHd_key: When true, the key input is typed as [B, S, N, H] instead
            of [B, N, S, H].
    """
    self.BSHd_key = BSHd_key
    self.with_mask = with_mask
    self.script_func = script_func
299319

300320
def get_onnx_model(self):
    """Return the IR model for the script, building and caching it on first use."""
    if hasattr(self, "_onnx_model"):
        return self._onnx_model

    # Query, value, and the output all share this type.
    bnsh_type = FLOAT[B, N, S, H]
    # Only the key type depends on the BSHd_key flag.
    key_type = FLOAT[B, S, N, H] if self.BSHd_key else bnsh_type

    input_types = [bnsh_type, key_type, bnsh_type]
    if self.with_mask:
        input_types.append(FLOAT[B, N, S, S])

    model_proto = self.script_func.to_model_proto(
        input_types=input_types, output_types=[bnsh_type]
    )
    self._onnx_model = ir.serde.deserialize_model(model_proto)
    return self._onnx_model
@@ -314,7 +335,9 @@ def get_ort_inputs(self):
314335
if not hasattr(self, "_ort_inputs"):
315336
inputs = {
316337
"query": numpy.random.rand(B, N, S, H).astype(numpy.float32),
317-
"key": numpy.random.rand(B, N, S, H).astype(numpy.float32),
338+
"key": numpy.random.rand(B, S, N, H).astype(numpy.float32)
339+
if self.BSHd_key
340+
else numpy.random.rand(B, N, S, H).astype(numpy.float32),
318341
"value": numpy.random.rand(B, N, S, H).astype(numpy.float32),
319342
}
320343
if self.with_mask:
@@ -374,10 +397,13 @@ class TestSDPAFusion(unittest.TestCase):
374397
"_custom_multi_scale_pre_mul_sdpa_script",
375398
_custom_multi_scale_pre_mul_sdpa_script,
376399
),
400+
("pre_div_sdpa_BSHd_key", _unmasked_pre_div_sdpa_BSHd_key_script),
377401
]
378402
)
379403
def test_sdpa_fusion(self, name, script_func):
380-
test_case = SDPATestCase(script_func, with_mask="masked" in name)
404+
test_case = SDPATestCase(
405+
script_func, with_mask="masked" in name, BSHd_key="BSHd_key" in name
406+
)
381407
model = test_case.get_onnx_model()
382408
onnxscript.optimizer.optimize(model)
383409

onnxscript/rewriter/ort_fusions/sdpa_via_mha.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,43 +7,56 @@
77
import onnx_ir as ir
88

99
from onnxscript.rewriter import _fusion_utils, pattern
10+
from onnxscript.rewriter._basics import MatchFailureError
1011

1112
Dim = Union[int, ir.SymbolicDim]
1213

1314

1415
class SDPAImplementation(pattern.RewriteRuleClassBase):
15-
def pattern(self, op, query, key, value, key_format):
    """Describe the SDPA node to match.

    ``key_format`` is passed through as the node's ``key_format`` attribute
    so that ``check`` and ``rewrite`` receive it as an argument.
    """
    sdpa_output = op.SDPA(
        query,
        key,
        value,
        key_format=key_format,
        _allow_other_inputs=True,  # Mask is optional
        _outputs=["sdpa_output"],
        _domain="ai.onnxruntime._fusion",
    )
    return sdpa_output
2526

26-
def check(self, context, query, key, value, key_format, sdpa_output):
    """Validate input shapes and record the head count.

    Query and value are checked first, then the key against the dimension
    order selected by ``key_format.value``.

    Returns:
        True when the "H" binding is a concrete ``int``; False otherwise.

    Raises:
        MatchFailureError: If ``key_format.value`` is neither "BHSd" nor "BSHd".
    """
    bindings: dict[str, Dim] = {}
    _fusion_utils.check_shape(bindings, query, ["B", "H", "S", "Dh"])
    _fusion_utils.check_shape(bindings, value, ["B", "H", "Skv", "Dv"])

    # Select the expected key dimension order once, then check the shape once.
    layout = key_format.value
    if layout == "BHSd":
        key_dims = ["B", "H", "Skv", "Dh"]
    elif layout == "BSHd":
        key_dims = ["B", "Skv", "H", "Dh"]
    else:
        raise MatchFailureError(
            f"Unexpected key_format value: {key_format.value}", key_format
        )
    _fusion_utils.check_shape(bindings, key, key_dims)

    # The head count is stored on self; rewrite reads self._num_heads later.
    self._num_heads = bindings["H"]
    if not isinstance(self._num_heads, int):
        return False
    self._use_mask_broadcast = True  # TODO: optimize to avoid broadcast if not needed
    # The guard above already established that the head count is an int.
    return True
3746

38-
def rewrite(self, op, query, key, value, sdpa_output):
47+
def rewrite(self, op, query, key, value, key_format, sdpa_output):
3948
sdpa_node = sdpa_output.producer()
4049
scale = sdpa_node.attributes.get("scale", None)
4150
to_3d_shape = op.Constant(value_ints=[0, 0, -1])
4251
to_4d_shape = op.Constant(value_ints=[0, 0, self._num_heads, -1])
4352
query_3d = op.Reshape(op.Transpose(query, perm=[0, 2, 1, 3]), to_3d_shape)
44-
key_3d = op.Reshape(op.Transpose(key, perm=[0, 2, 1, 3]), to_3d_shape)
4553
value_3d = op.Reshape(op.Transpose(value, perm=[0, 2, 1, 3]), to_3d_shape)
4654

55+
if key_format.value == "BHSd":
56+
key_3d = op.Reshape(op.Transpose(key, perm=[0, 2, 1, 3]), to_3d_shape)
57+
else: # BSHd
58+
key_3d = op.Reshape(key, to_3d_shape)
59+
4760
inputs = [query_3d, key_3d, value_3d]
4861
if len(sdpa_node.inputs) > 3:
4962
mask = sdpa_node.inputs[3]

0 commit comments

Comments
 (0)