|
1 | 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. |
2 | 2 |
|
3 | 3 | # pyre-strict |
4 | | - |
5 | 4 | import torch |
6 | | -from executorch.backends.cadence.aot.program_builder import ProgramBuilder |
| 5 | +from executorch.backends.cadence.aot.program_builder import IrMode, ProgramBuilder |
| 6 | +from executorch.exir.dialects._ops import ops as exir_ops |
7 | 7 | from later.unittest import TestCase |
| 8 | +from torch._export.verifier import SpecViolationError |
8 | 9 | from torch.export.graph_signature import InputKind, OutputKind |
9 | 10 |
|
10 | 11 |
|
@@ -120,3 +121,102 @@ def test_user_input_mutation(self) -> None: |
120 | 121 | self.assertEqual( |
121 | 122 | program.graph_signature.output_specs[0].kind, OutputKind.USER_INPUT_MUTATION |
122 | 123 | ) |
| 124 | + |
| 125 | + def test_get_verifier_exir_mode(self) -> None: |
| 126 | + """Test that get_verifier returns EXIREdgeDialectVerifier for EXIR mode.""" |
| 127 | + builder = ProgramBuilder(mode=IrMode.EXIR) |
| 128 | + verifiers = builder.get_verifiers() |
| 129 | + self.assertIsNotNone(verifiers) |
| 130 | + self.assertEqual(len(verifiers), 1) |
| 131 | + |
| 132 | + def test_get_verifier_aten_mode(self) -> None: |
| 133 | + """Test that get_verifier returns None for ATEN mode.""" |
| 134 | + builder = ProgramBuilder(mode=IrMode.ATEN) |
| 135 | + verifiers = builder.get_verifiers() |
| 136 | + self.assertIsNone(verifiers) |
| 137 | + |
| 138 | + def test_get_verifier_default_mode(self) -> None: |
| 139 | + """Test that get_verifier returns EXIREdgeDialectVerifier for default mode.""" |
| 140 | + builder = ProgramBuilder() # Should default to EXIR |
| 141 | + self.assertEqual(builder.mode, IrMode.EXIR) |
| 142 | + verifiers = builder.get_verifiers() |
| 143 | + self.assertIsNotNone(verifiers) |
| 144 | + self.assertEqual(len(verifiers), 1) |
| 145 | + |
| 146 | + def test_aten_add_tensor_exir_mode(self) -> None: |
| 147 | + """Test using torch.ops.aten.add.Tensor with EXIR mode.""" |
| 148 | + inp = torch.randn([3, 5]) |
| 149 | + buffer = torch.randn([5]) |
| 150 | + |
| 151 | + builder = ProgramBuilder(mode=IrMode.EXIR) |
| 152 | + inp_proxy = builder.placeholder("inp", inp) |
| 153 | + buffer_proxy = builder.placeholder( |
| 154 | + "buffer", buffer, input_kind=InputKind.BUFFER |
| 155 | + ) |
| 156 | + add = builder.call_operator( |
| 157 | + torch.ops.aten.add.Tensor, (inp_proxy, buffer_proxy) |
| 158 | + ) |
| 159 | + builder.output([add]) |
| 160 | + builder.get_program() |
| 161 | + |
| 162 | + def test_aten_add_tensor_aten_mode(self) -> None: |
| 163 | + """Test using torch.ops.aten.add.Tensor with ATEN mode.""" |
| 164 | + inp = torch.randn([3, 5]) |
| 165 | + buffer = torch.randn([5]) |
| 166 | + |
| 167 | + builder = ProgramBuilder(mode=IrMode.ATEN) |
| 168 | + inp_proxy = builder.placeholder("inp", inp) |
| 169 | + buffer_proxy = builder.placeholder( |
| 170 | + "buffer", buffer, input_kind=InputKind.BUFFER |
| 171 | + ) |
| 172 | + add = builder.call_operator( |
| 173 | + torch.ops.aten.add.Tensor, (inp_proxy, buffer_proxy) |
| 174 | + ) |
| 175 | + builder.output([add]) |
| 176 | + program = builder.get_program() |
| 177 | + |
| 178 | + # Verify the program was created successfully |
| 179 | + self.assertEqual(len(program.graph_signature.input_specs), 2) |
| 180 | + self.assertEqual(len(program.graph_signature.output_specs), 1) |
| 181 | + self.assertEqual(builder.mode, IrMode.ATEN) |
| 182 | + |
| 183 | + def test_exir_edge_aten_add_tensor_exir_mode(self) -> None: |
| 184 | + """Test using exir_ops.edge.aten.add.Tensor with EXIR mode.""" |
| 185 | + inp = torch.randn([3, 5]) |
| 186 | + buffer = torch.randn([5]) |
| 187 | + |
| 188 | + builder_exir = ProgramBuilder(mode=IrMode.EXIR) |
| 189 | + inp_proxy_exir = builder_exir.placeholder("inp", inp) |
| 190 | + buffer_proxy_exir = builder_exir.placeholder( |
| 191 | + "buffer", buffer, input_kind=InputKind.BUFFER |
| 192 | + ) |
| 193 | + add_exir = builder_exir.call_operator( |
| 194 | + exir_ops.edge.aten.add.Tensor, (inp_proxy_exir, buffer_proxy_exir) |
| 195 | + ) |
| 196 | + builder_exir.output([add_exir]) |
| 197 | + program_exir = builder_exir.get_program() |
| 198 | + |
| 199 | + # Verify the program was created successfully |
| 200 | + self.assertEqual(len(program_exir.graph_signature.input_specs), 2) |
| 201 | + self.assertEqual(len(program_exir.graph_signature.output_specs), 1) |
| 202 | + self.assertEqual(builder_exir.mode, IrMode.EXIR) |
| 203 | + |
| 204 | + def test_exir_edge_aten_add_tensor_aten_mode(self) -> None: |
| 205 | + """Test using exir_ops.edge.aten.add.Tensor with ATEN mode.""" |
| 206 | + inp = torch.randn([3, 5]) |
| 207 | + buffer = torch.randn([5]) |
| 208 | + |
| 209 | + builder_aten = ProgramBuilder(mode=IrMode.ATEN) |
| 210 | + inp_proxy_aten = builder_aten.placeholder("inp", inp) |
| 211 | + buffer_proxy_aten = builder_aten.placeholder( |
| 212 | + "buffer", buffer, input_kind=InputKind.BUFFER |
| 213 | + ) |
| 214 | + add_aten = builder_aten.call_operator( |
| 215 | + exir_ops.edge.aten.add.Tensor, (inp_proxy_aten, buffer_proxy_aten) |
| 216 | + ) |
| 217 | + builder_aten.output([add_aten]) |
| 218 | + |
| 219 | + with self.assertRaises( |
| 220 | + SpecViolationError, msg="Operator '<EdgeOpOverload: aten.add.Tensor>" |
| 221 | + ): |
| 222 | + builder_aten.get_program() |
0 commit comments