|
1 | 1 | import unittest |
2 | 2 |
|
3 | 3 | import torch |
4 | | -from executorch.backends.qualcomm._passes import InsertReshapeForReduceOps |
| 4 | +from executorch.backends.qualcomm._passes import ( |
| 5 | + ConvertBmmToMatmul, |
| 6 | + ConvertMhaToSha, |
| 7 | + InsertReshapeForReduceOps, |
| 8 | + RemoveRedundancy, |
| 9 | +) |
| 10 | + |
| 11 | +from executorch.exir import to_edge |
| 12 | +from executorch.exir.dialects._ops import ops as exir_ops |
5 | 13 |
|
6 | 14 |
|
7 | 15 | class TestPasses(unittest.TestCase): |
@@ -49,6 +57,90 @@ def forward(self, x): |
49 | 57 | torch.equal(*out, ref), f"Output mismatch: got {out}, expected {ref}" |
50 | 58 | ) |
51 | 59 |
|
| 60 | + def test_mha_to_sha(self): |
| 61 | + from executorch.backends.qualcomm.utils.utils import convert_linear_to_conv2d |
| 62 | + from executorch.examples.models.llama.model_args import ModelArgs |
| 63 | + from executorch.examples.qualcomm.oss_scripts.llama.masking_utils import ( |
| 64 | + CausalAttentionMask, |
| 65 | + ) |
| 66 | + from executorch.examples.qualcomm.oss_scripts.llama.model.static_llama import ( |
| 67 | + LlamaAttention, |
| 68 | + ) |
| 69 | + |
| 70 | + # Initailize model config |
| 71 | + args = ModelArgs() |
| 72 | + args.max_seq_len = 128 |
| 73 | + args.ar_len = 32 |
| 74 | + args.use_kv_cache = True |
| 75 | + args.dim = 32 |
| 76 | + args.n_heads = 8 |
| 77 | + args.n_kv_heads = 8 |
| 78 | + args.n_layers = 2 |
| 79 | + args.head_dim = args.dim // args.n_heads |
| 80 | + mod = convert_linear_to_conv2d(LlamaAttention(0, args, True)) |
| 81 | + |
| 82 | + # Prepare inputs |
| 83 | + hidden_states = torch.randint( |
| 84 | + low=0, |
| 85 | + high=100, |
| 86 | + size=(args.max_batch_size, args.ar_len, args.dim), |
| 87 | + dtype=torch.float32, |
| 88 | + ) |
| 89 | + freqs_cos = torch.randn(args.ar_len, 1) |
| 90 | + freqs_sin = torch.randn(args.ar_len, 1) |
| 91 | + atten_mask = CausalAttentionMask( |
| 92 | + args.max_batch_size, args.ar_len, args.max_seq_len |
| 93 | + ) |
| 94 | + k_cache = torch.zeros( |
| 95 | + args.max_batch_size, |
| 96 | + args.n_kv_heads, |
| 97 | + args.head_dim, |
| 98 | + args.max_seq_len - args.ar_len, |
| 99 | + ) |
| 100 | + |
| 101 | + v_cache = torch.zeros( |
| 102 | + args.max_batch_size, |
| 103 | + args.n_kv_heads, |
| 104 | + args.max_seq_len - args.ar_len, |
| 105 | + args.head_dim, |
| 106 | + ) |
| 107 | + sample_input = ( |
| 108 | + hidden_states, |
| 109 | + freqs_cos, |
| 110 | + freqs_sin, |
| 111 | + atten_mask.mask, |
| 112 | + k_cache, |
| 113 | + v_cache, |
| 114 | + ) |
| 115 | + |
| 116 | + # Export the module and convert linear to conv2d |
| 117 | + edge_program = to_edge(torch.export.export(mod, sample_input)) |
| 118 | + new_ep = edge_program.exported_program() |
| 119 | + |
| 120 | + conv_nodes = [ |
| 121 | + n |
| 122 | + for n in new_ep.graph.nodes |
| 123 | + if n.target == exir_ops.edge.aten.convolution.default |
| 124 | + ] |
| 125 | + # WQ, WK, WV, O |
| 126 | + self.assertTrue(len(conv_nodes) == 4, "Convolution nodes missing") |
| 127 | + |
| 128 | + # Convert MHA to SHA |
| 129 | + # This is a simplified version of what happens in the full pipeline to test the core functionality |
| 130 | + graph_module = RemoveRedundancy(quantization_capture=False)( |
| 131 | + new_ep.graph_module |
| 132 | + ).graph_module |
| 133 | + graph_module = ConvertBmmToMatmul()(graph_module).graph_module |
| 134 | + graph_module = ConvertMhaToSha(new_ep)(graph_module).graph_module |
| 135 | + |
| 136 | + conv_nodes = [ |
| 137 | + n |
| 138 | + for n in new_ep.graph.nodes |
| 139 | + if n.target == exir_ops.edge.aten.convolution.default |
| 140 | + ] |
| 141 | + # Check graph structure: WQ, WK, WV should be converted to SHA |
| 142 | + self.assertTrue(len(conv_nodes) == 25, "Convolution nodes should be splited") |
| 143 | + |
52 | 144 |
|
53 | 145 | if __name__ == "__main__": |
54 | 146 | unittest.main() |
0 commit comments