Skip to content

Commit 8d4bc9b

Browse files
Eashan Gargfacebook-github-bot
authored andcommitted
Conv Combos with Simulator (#6977)
Summary: Conv Combos that fail in simulator Differential Revision: D66217464
1 parent 82763a9 commit 8d4bc9b

File tree

1 file changed

+70
-0
lines changed

1 file changed

+70
-0
lines changed

backends/arm/test/ops/test_conv_combos.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,33 @@ def _test_conv_combo_ethos_BI_pipeline(
256256
if conftest.is_option_enabled("corstone_fvp"):
257257
tester.run_method_and_compare_outputs(qtol=1, inputs=test_data)
258258

259+
def _test_conv_combo_ethos_MI_pipeline(
260+
self,
261+
module: torch.nn.Module,
262+
compile_spec: CompileSpec,
263+
test_data: Tuple[torch.Tensor],
264+
atol: float = 1e-3,
265+
rtol: float = 1e-3,
266+
):
267+
(
268+
ArmTester(
269+
module,
270+
example_inputs=test_data,
271+
compile_spec=compile_spec,
272+
)
273+
.quantize()
274+
.export()
275+
.to_edge()
276+
.partition()
277+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
278+
.check_not(list(module.edge_op_list))
279+
.to_executorch()
280+
.serialize()
281+
.run_method_and_compare_outputs(
282+
inputs=test_data, atol=atol, rtol=rtol, qtol=1
283+
)
284+
)
285+
259286
####################
260287
## Conv + meandim ##
261288
####################
@@ -276,6 +303,15 @@ def test_conv_meandim_u55_BI(self):
276303
model.get_inputs(),
277304
)
278305

306+
@pytest.mark.corstone_fvp
307+
def test_conv_meandim_u55_MI(self):
308+
model = ComboConv2dMeandim()
309+
self._test_conv_combo_ethos_MI_pipeline(
310+
model,
311+
common.get_u55_compile_spec(permute_memory_to_nhwc=True),
312+
model.get_inputs(),
313+
)
314+
279315
@pytest.mark.corstone_fvp
280316
def test_conv_meandim_u85_BI(self):
281317
model = ComboConv2dMeandim()
@@ -303,6 +339,13 @@ def test_conv_batchnorm_relu6_u55_BI(self):
303339
model, common.get_u55_compile_spec(), model.get_inputs()
304340
)
305341

342+
@pytest.mark.corstone_fvp
343+
def test_conv_batchnorm_relu6_u55_MI(self):
344+
model = ComboConvBatchnormRelu6()
345+
self._test_conv_combo_ethos_MI_pipeline(
346+
model, common.get_u55_compile_spec(), model.get_inputs()
347+
)
348+
306349
@pytest.mark.corstone_fvp
307350
def test_conv_batchnorm_relu_u85_BI(self):
308351
model = ComboConvBatchnormRelu6()
@@ -336,6 +379,14 @@ def test_conv_relu6_u55_BI(self, test_data: torch.Tensor):
336379
model, common.get_u55_compile_spec(permute_memory_to_nhwc=True), test_data
337380
)
338381

382+
@parameterized.expand(ComboConvRelu6.test_data)
383+
def test_conv_relu6_u55_MI(self, test_data: torch.Tensor):
384+
model = ComboConvRelu6()
385+
test_data = (test_data,)
386+
self._test_conv_combo_ethos_MI_pipeline(
387+
model, common.get_u55_compile_spec(permute_memory_to_nhwc=True), test_data
388+
)
389+
339390
@parameterized.expand(ComboConvRelu6.test_data)
340391
@pytest.mark.corstone_fvp
341392
def test_conv_relu6_u85_BI(self, test_data: torch.Tensor):
@@ -367,6 +418,15 @@ def test_block_bottleneck_residual_u55_BI(self):
367418
model.get_inputs(),
368419
)
369420

421+
@pytest.mark.corstone_fvp
422+
def test_block_bottleneck_residual_u55_MI(self):
423+
model = ComboBlockBottleneckResidual()
424+
self._test_conv_combo_ethos_MI_pipeline(
425+
model,
426+
common.get_u55_compile_spec(permute_memory_to_nhwc=True),
427+
model.get_inputs(),
428+
)
429+
370430
@pytest.mark.corstone_fvp
371431
def test_block_bottleneck_residual_u85_BI(self):
372432
model = ComboBlockBottleneckResidual()
@@ -402,6 +462,16 @@ def test_conv_avgpool2d_u55_BI(self, test_data: torch.Tensor):
402462
test_data,
403463
)
404464

465+
@parameterized.expand(ComboConvAvgPool2d.test_data)
466+
def test_conv_avgpool2d_u55_MI(self, test_data: torch.Tensor):
467+
model = ComboConvAvgPool2d()
468+
test_data = (test_data,)
469+
self._test_conv_combo_ethos_MI_pipeline(
470+
model,
471+
common.get_u55_compile_spec(),
472+
test_data,
473+
)
474+
405475
@parameterized.expand(ComboConvAvgPool2d.test_data)
406476
@pytest.mark.corstone_fvp
407477
def test_conv_avgpool2d_u85_BI(self, test_data: torch.Tensor):

0 commit comments

Comments
 (0)