@@ -13,6 +13,7 @@
 
 from executorch.backends.apple.coreml.compiler import CoreMLBackend
 from executorch.backends.apple.coreml.partition import CoreMLPartitioner
+from executorch.exir.backend.utils import format_delegated_graph
 
 
 class TestCoreMLPartitioner(unittest.TestCase):
@@ -79,6 +80,50 @@ def test_vit_skip_conv(self):
             "getitem",
         ]
 
+    def test_ops_to_not_decompose(self):
+        class Model(torch.nn.Module):
+            def forward(self, q, k, v, mask):
+                return torch.ops.aten.scaled_dot_product_attention.default(
+                    q, k, v, attn_mask=mask
+                )
+
+        model = Model()
+        model.eval()
+
+        batch_size = 1
+        n_heads = 12
+        seq_len = 1
+        max_seq_length = 32
+        embedding_dim = 16
+        q = torch.randn(batch_size, n_heads, seq_len, embedding_dim)
+        k = torch.randn(batch_size, n_heads, max_seq_length, embedding_dim)
+        v = torch.randn(batch_size, n_heads, max_seq_length, embedding_dim)
+        mask = torch.randn(seq_len, max_seq_length)
+        example_inputs = (q, k, v, mask)
+        ep = torch.export.export(model, example_inputs)
+        coreml_partitioner = CoreMLPartitioner()
+
+        # With to_edge_transform_and_lower, we expect SDPA to be preserved and to show up in the delegated graph
+        edge_program_manager = executorch.exir.to_edge_transform_and_lower(
+            ep, partitioner=[coreml_partitioner]
+        )
+        self.assertTrue(
+            "executorch.exir.dialects.edge._ops.aten.scaled_dot_product_attention.default"
+            in format_delegated_graph(
+                edge_program_manager.exported_program().graph_module
+            )
+        )
+
+        # With the plain to_edge flow, we expect SDPA to be decomposed and not show up in the delegated graph
+        edge_program_manager2 = executorch.exir.to_edge(ep)
+        edge_program_manager2.to_backend(coreml_partitioner)
+        self.assertTrue(
+            "executorch.exir.dialects.edge._ops.aten.scaled_dot_product_attention.default"
+            not in format_delegated_graph(
+                edge_program_manager2.exported_program().graph_module
+            )
+        )
+
     def test_buffer(self):
         embedding_dim = 3
         max_seq_len = 2
@@ -129,4 +174,5 @@ def forward(self, q, k_val, input_pos):
     test_runner = TestCoreMLPartitioner()
     test_runner.test_add_sub_skip_mm()
     test_runner.test_vit_skip_conv()
+    test_runner.test_ops_to_not_decompose()
     test_runner.test_buffer()
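
For context on what the two flows in the new test exercise: `to_edge_transform_and_lower` consults each partitioner for ATen ops that should be kept whole before running decompositions, while plain `to_edge` decomposes unconditionally and never asks. Below is a minimal sketch of a partitioner using that hook; the `ops_to_not_decompose` signature is assumed from the ExecuTorch `Partitioner` API at the time of this change, and `MyPartitioner` is purely illustrative, not part of this diff.

```python
# Sketch (illustrative): how a partitioner advertises ops that
# to_edge_transform_and_lower should not decompose. Signature assumed
# from executorch.exir.backend.partitioner.Partitioner.
from typing import Callable, List, Optional, Tuple

import torch
from executorch.exir.backend.partitioner import Partitioner, PartitionResult
from torch.export import ExportedProgram


class MyPartitioner(Partitioner):
    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
        ...  # tag delegable nodes as usual

    def ops_to_not_decompose(
        self, ep: ExportedProgram
    ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
        # Ops returned here survive decomposition under
        # to_edge_transform_and_lower; plain to_edge never calls this hook,
        # which is exactly the difference the test above asserts.
        return [torch.ops.aten.scaled_dot_product_attention.default], None
```

With a partitioner like this, the first flow in the test keeps `aten.scaled_dot_product_attention.default` intact so it can be delegated as a single node, while the second flow breaks it into its constituent ops before partitioning, so the fused op never appears in the delegated graph.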