Commit a666afa

Add a test to check functionality of ConvertMhaToSha
1 parent b199e88 commit a666afa

3 files changed: +102 −4 lines changed


backends/qualcomm/_passes/convert_bmm_to_matmul.py

Lines changed: 7 additions & 1 deletion
@@ -47,7 +47,13 @@ def call(self, graph_module: torch.fx.GraphModule):
         graph = graph_module.graph
         partitions = get_source_partitions(
             graph,
-            [operator.matmul, torch.matmul, torch.bmm, torch.ops.aten.matmul.default],
+            [
+                "matmul",
+                operator.matmul,
+                torch.matmul,
+                torch.bmm,
+                torch.ops.aten.matmul.default,
+            ],
         )
         for _, src_partitions in partitions.items():
             for src_partition in src_partitions:
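For context, get_source_partitions groups graph nodes by the source_fn_stack metadata that torch.export records, and depending on the PyTorch version that source may be recorded as the callable (e.g. torch.matmul) or as its string name ("matmul"), which is presumably why the string key is added here. A minimal standalone sketch (not part of this commit) that queries partitions with both forms:

import operator

import torch
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions


class TinyMatmul(torch.nn.Module):
    def forward(self, a, b):
        return torch.matmul(a, b)


ep = torch.export.export(TinyMatmul(), (torch.randn(2, 3), torch.randn(3, 4)))
partitions = get_source_partitions(
    ep.graph_module.graph,
    # Same wanted-sources list as the pass above; whichever form the trace
    # recorded will match.
    ["matmul", operator.matmul, torch.matmul, torch.bmm, torch.ops.aten.matmul.default],
)
print(partitions.keys())  # expect one "matmul"-flavored partition key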

backends/qualcomm/_passes/convert_mha_to_sha.py

Lines changed: 2 additions & 2 deletions
@@ -126,7 +126,7 @@ def _get_attention_output(self, softmax):
         pattern_qk = [_is_softmax, "*", lambda x: _is_matmul(x) or _is_bmm(x)]
         qk = find_pattern(softmax, pattern_qk)
         if not qk:
-            return None, None
+            return None, None, None

         patterns_qkv = [
             _is_softmax,
@@ -139,7 +139,7 @@ def _get_attention_output(self, softmax):

         qkv = find_pattern(softmax, patterns_qkv, from_args=False)
         if qkv is None:
-            return None, None
+            return None, None, None

         permute, reshape = qkv[0][-2:]
         matmul = qkv[0][2]
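The arity fix matters because the success path of _get_attention_output yields three values; with the old two-value early returns, a caller that unpacks three names would raise ValueError on the not-found path instead of reaching its None check. A self-contained sketch of the failure mode (names illustrative, not from this file):

def get_output_sketch(found: bool):
    if not found:
        return None, None, None  # after the fix: arity matches the success path
    return "matmul", "permute", "reshape"

# With the old `return None, None`, this unpacking would raise
# "ValueError: not enough values to unpack" on the not-found path.
matmul, permute, reshape = get_output_sketch(False)
assert matmul is None  # callers can gate on the first value as before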

backends/qualcomm/tests/test_passes.py

Lines changed: 93 additions & 1 deletion
@@ -1,7 +1,15 @@
 import unittest

 import torch
-from executorch.backends.qualcomm._passes import InsertReshapeForReduceOps
+from executorch.backends.qualcomm._passes import (
+    ConvertBmmToMatmul,
+    ConvertMhaToSha,
+    InsertReshapeForReduceOps,
+    RemoveRedundancy,
+)
+
+from executorch.exir import to_edge
+from executorch.exir.dialects._ops import ops as exir_ops


 class TestPasses(unittest.TestCase):
@@ -49,6 +57,90 @@ def forward(self, x):
             torch.equal(*out, ref), f"Output mismatch: got {out}, expected {ref}"
         )

+    def test_mha_to_sha(self):
+        from executorch.backends.qualcomm.utils.utils import convert_linear_to_conv2d
+        from executorch.examples.models.llama.model_args import ModelArgs
+        from executorch.examples.qualcomm.oss_scripts.llama.masking_utils import (
+            CausalAttentionMask,
+        )
+        from executorch.examples.qualcomm.oss_scripts.llama.model.static_llama import (
+            LlamaAttention,
+        )
+
+        # Initialize model config
+        args = ModelArgs()
+        args.max_seq_len = 128
+        args.ar_len = 32
+        args.use_kv_cache = True
+        args.dim = 32
+        args.n_heads = 8
+        args.n_kv_heads = 8
+        args.n_layers = 2
+        args.head_dim = args.dim // args.n_heads
+        mod = convert_linear_to_conv2d(LlamaAttention(0, args, True))
+
+        # Prepare inputs
+        hidden_states = torch.randint(
+            low=0,
+            high=100,
+            size=(args.max_batch_size, args.ar_len, args.dim),
+            dtype=torch.float32,
+        )
+        freqs_cos = torch.randn(args.ar_len, 1)
+        freqs_sin = torch.randn(args.ar_len, 1)
+        atten_mask = CausalAttentionMask(
+            args.max_batch_size, args.ar_len, args.max_seq_len
+        )
+        k_cache = torch.zeros(
+            args.max_batch_size,
+            args.n_kv_heads,
+            args.head_dim,
+            args.max_seq_len - args.ar_len,
+        )
+
+        v_cache = torch.zeros(
+            args.max_batch_size,
+            args.n_kv_heads,
+            args.max_seq_len - args.ar_len,
+            args.head_dim,
+        )
+        sample_input = (
+            hidden_states,
+            freqs_cos,
+            freqs_sin,
+            atten_mask.mask,
+            k_cache,
+            v_cache,
+        )
+
+        # Export the module to the edge dialect
+        edge_program = to_edge(torch.export.export(mod, sample_input))
+        new_ep = edge_program.exported_program()
+
+        conv_nodes = [
+            n
+            for n in new_ep.graph.nodes
+            if n.target == exir_ops.edge.aten.convolution.default
+        ]
+        # WQ, WK, WV, O
+        self.assertTrue(len(conv_nodes) == 4, "Convolution nodes missing")
+
+        # Convert MHA to SHA
+        # A simplified version of the full pipeline, to test the core functionality
+        graph_module = RemoveRedundancy(quantization_capture=False)(
+            new_ep.graph_module
+        ).graph_module
+        graph_module = ConvertBmmToMatmul()(graph_module).graph_module
+        graph_module = ConvertMhaToSha(new_ep)(graph_module).graph_module
+
+        conv_nodes = [
+            n
+            for n in new_ep.graph.nodes
+            if n.target == exir_ops.edge.aten.convolution.default
+        ]
+        # Check graph structure: WQ, WK, WV should be converted to SHA
+        self.assertTrue(len(conv_nodes) == 25, "Convolution nodes should be split")
+

 if __name__ == "__main__":
     unittest.main()
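The expected counts follow from the attention layout: before the pass, the four projections (WQ, WK, WV, O) each lower to one convolution; afterwards, assuming ConvertMhaToSha splits only the Q/K/V projections into one single-head convolution per head and leaves the output projection fused (inferred from the test's assertion, not stated in the diff), 8 heads * 3 projections + 1 = 25. A quick sanity check of that arithmetic:

n_heads = 8  # matches args.n_heads in the test above
split_projections = 3  # WQ, WK, WV each become one conv per head
assert n_heads * split_projections + 1 == 25  # +1: the O projection stays whole

The test should also be runnable on its own, e.g. with python -m unittest backends.qualcomm.tests.test_passes from the executorch repo root, assuming the package layout resolves.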
