
Commit aeee0fa

Address PR feedback

Signed-off-by: Ganesan Ramalingam <[email protected]>

1 parent 21a1594 commit aeee0fa

File tree

2 files changed: +13 -0 lines changed

onnxscript/rewriter/ort_fusions/sdpa.py

Lines changed: 12 additions & 0 deletions

@@ -12,6 +12,18 @@

 Dim = Union[int, ir.SymbolicDim]

+# This file contains a fusion rule that recognizes various patterns of scaled dot-product attention
+# (SDPA) implementations and replaces them with a single SDPA op. The SDPA op is a temporary fusion
+# op defined in the ai.onnxruntime._fusion domain. Subsequent fusion rules will map it into one
+# of the various ops defined in ORT: MHA, GQA, or Attention, depending on the input patterns.
+# The SDPA op is a standard scaled dot-product attention with an optional mask input and scaling factor.
+# Currently, it is restricted to query, key, and value inputs of rank 4 with shapes:
+#   Query: [batch_size, num_heads, seq_len, head_size_qk]
+#   Key:   [batch_size, num_heads, seq_len_kv, head_size_qk]
+#      or  [batch_size, seq_len_kv, num_heads, head_size_qk]
+#   Value: [batch_size, num_heads, seq_len_kv, head_size_v]
+# The key_format attribute indicates which of the two formats the key uses and can be either "BHSd" or "BSHd".
+

 class SDPA(pattern.RewriteRuleClassBase):
     _scale: float | None
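
For readers of the diff above, here is a minimal NumPy sketch of the computation the fused SDPA op stands for, under the shape conventions documented in the new comment. The function name, the default 1/sqrt(head_size_qk) scale, and the additive-mask interpretation are illustrative assumptions, not taken from this commit.

import numpy as np

def sdpa_reference(query, key, value, mask=None, scale=None, key_format="BHSd"):
    # query: [B, H, S_q, d_qk]; value: [B, H, S_kv, d_v]
    # key:   [B, H, S_kv, d_qk] when key_format == "BHSd",
    #        [B, S_kv, H, d_qk] when key_format == "BSHd"
    if key_format == "BSHd":
        key = np.transpose(key, (0, 2, 1, 3))  # normalize key to [B, H, S_kv, d_qk]
    if scale is None:
        scale = 1.0 / np.sqrt(query.shape[-1])  # assumed default scaling factor
    scores = (query @ np.transpose(key, (0, 1, 3, 2))) * scale  # [B, H, S_q, S_kv]
    if mask is not None:
        scores = scores + mask  # additive mask; mask semantics are an assumption here
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerically stable softmax
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ value  # [B, H, S_q, d_v]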

onnxscript/rewriter/ort_fusions/sdpa_via_mha.py

Lines changed: 1 addition & 0 deletions

@@ -14,6 +14,7 @@

 class SDPAImplementation(pattern.RewriteRuleClassBase):
     def pattern(self, op, query, key, value, key_format):
+        """Pattern matches any call to SDPA. See sdpa.py for documentation on the SDPA op."""
         return op.SDPA(
             query,
             key,
