diff --git a/backends/arm/test/passes/test_cast_int64_pass.py b/backends/arm/test/passes/test_cast_int64_pass.py index fdfab1f3af8..0465a85deb9 100644 --- a/backends/arm/test/passes/test_cast_int64_pass.py +++ b/backends/arm/test/passes/test_cast_int64_pass.py @@ -1,17 +1,16 @@ # Copyright 2025 Arm Limited and/or its affiliates. -# All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import unittest +from typing import Tuple import torch from executorch.backends.arm._passes.cast_int64_pass import CastInt64ToInt32Pass -from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.test_pipeline import TestPassPipeline -from executorch.backends.arm.test.tester.arm_tester import ArmTester, RunPasses +input_t = Tuple[torch.Tensor] # Input x class Int64Model(torch.nn.Module): @@ -19,26 +18,27 @@ class Int64Model(torch.nn.Module): def forward(self, x: torch.Tensor): return x + 3 - def get_inputs(self): + def get_inputs(self) -> input_t: return (torch.rand(4),) -class TestCastInt64Pass(unittest.TestCase): - - def test_int64_model(self): - module = Int64Model() - test_pass_stage = RunPasses(passes_with_exported_program=[CastInt64ToInt32Pass]) - tester = ( - ArmTester( - module, - example_inputs=module.get_inputs(), - compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"), - ) - .export() - .to_edge() - .run_passes(test_pass_stage) - .run_method_and_compare_outputs() - ) - exported_program = tester.get_artifact("RunPasses").exported_program() - for state in exported_program.state_dict: - assert exported_program.state_dict[state].dtype != torch.int64 +def test_int64_model_tosa_BI(): + module = Int64Model() + op_checks = { + "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 1, + "executorch_exir_dialects_edge__ops_aten_add_Tensor": 1, + } + pipeline = TestPassPipeline[input_t]( + module, + module.get_inputs(), + tosa_version="TOSA-0.80+BI", + ops_before_pass=op_checks, + ops_after_pass=op_checks, + passes_with_exported_program=[CastInt64ToInt32Pass], + ) + pipeline.pop_stage("quantize") + pipeline.run() + + exported_program = pipeline.tester.get_artifact("RunPasses").exported_program() + for state in exported_program.state_dict: + assert exported_program.state_dict[state].dtype == torch.int32 diff --git a/backends/arm/test/passes/test_fold_qdq_pass.py b/backends/arm/test/passes/test_fold_qdq_pass.py index ebb96faf906..f63fa33bca1 100644 --- a/backends/arm/test/passes/test_fold_qdq_pass.py +++ b/backends/arm/test/passes/test_fold_qdq_pass.py @@ -1,63 +1,50 @@ # Copyright 2024-2025 Arm Limited and/or its affiliates. -# All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import unittest +from typing import Tuple import torch from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( FoldAndAnnotateQParamsPass, ) +from executorch.backends.arm.test.tester.test_pipeline import TestPassPipeline -from executorch.backends.arm.test import common -from executorch.backends.arm.test.tester.arm_tester import ArmTester -from executorch.backends.xnnpack.test.tester.tester import RunPasses +input_t = Tuple[torch.Tensor, torch.Tensor] # Input x, y class SimpleQuantizeModel(torch.nn.Module): def forward(self, x, y): return x + torch.max((x + x), (y + y)) - def get_inputs(self): + def get_inputs(self) -> input_t: return (torch.rand(1, 1280, 7, 7), torch.rand(1, 1280, 7, 7)) -class TestFoldAndAnnotateQParamsPass(unittest.TestCase): +def test_fold_qdq_pass_tosa_BI(): """ Tests the FoldAndAnnotateQParamsPass which folds dq/q nodes into the node and stores the quantization parameters in meta. - """ - def test_fold_qdq_pass(self): - """ - Check that the pass runs for add operation and that one q node and one dq node - is removed from the representation. - """ - module = SimpleQuantizeModel() - test_pass_stage = RunPasses([FoldAndAnnotateQParamsPass]) - ( - ArmTester( - module, - example_inputs=module.get_inputs(), - compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"), - ) - .quantize() - .export() - .to_edge() - .check_count( - { - "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 7, - "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 6, - } - ) - .run_passes(test_pass_stage) - .check_count( - { - "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 1, - "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, - } - ) - ) + Check that the pass runs for add operation and that one q node and one dq node + is removed from the representation. + """ + module = SimpleQuantizeModel() + pipeline = TestPassPipeline[input_t]( + module, + module.get_inputs(), + tosa_version="TOSA-0.80+BI", + ops_before_pass={ + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 7, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 6, + }, + ops_after_pass={ + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + }, + pass_list=[FoldAndAnnotateQParamsPass], + ) + pipeline.pop_stage(-1) # Do not compare output + pipeline.run() diff --git a/backends/arm/test/passes/test_fuse_batchnorm_pass.py b/backends/arm/test/passes/test_fuse_batchnorm_pass.py index 09f8f578fc2..b18e536b155 100644 --- a/backends/arm/test/passes/test_fuse_batchnorm_pass.py +++ b/backends/arm/test/passes/test_fuse_batchnorm_pass.py @@ -1,15 +1,16 @@ # Copyright 2025 Arm Limited and/or its affiliates. -# All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import unittest + +from typing import Tuple import torch from executorch.backends.arm._passes.fuse_batchnorm2d_pass import FuseBatchnorm2DPass from executorch.backends.arm.test import common -from executorch.backends.arm.test.tester.arm_tester import ArmTester, RunPasses -from parameterized import parameterized +from executorch.backends.arm.test.tester.test_pipeline import TestPassPipeline + +input_t = Tuple[torch.Tensor] # Input x class MergeOneOfTwoBN(torch.nn.Module): @@ -35,7 +36,7 @@ def __init__(self, affine: bool): self.batch_norm2d.bias = torch.nn.Parameter(torch.rand(3)) self.relu6 = torch.nn.ReLU6() - def get_inputs(self) -> tuple[torch.Tensor]: + def get_inputs(self) -> input_t: return (torch.randn(1, 3, 256, 256),) def forward(self, x): @@ -72,7 +73,7 @@ def __init__(self, affine: bool): self.batch_norm2d.bias = torch.nn.Parameter(torch.rand(3)) self.relu6 = torch.nn.ReLU6() - def get_inputs(self) -> tuple[torch.Tensor]: + def get_inputs(self) -> input_t: return (torch.randn(1, 3, 256, 256),) def forward(self, x): @@ -110,7 +111,7 @@ def __init__(self, affine: bool): self.batch_norm2d.bias = torch.nn.Parameter(torch.rand(3)) self.relu6 = torch.nn.ReLU6() - def get_inputs(self) -> tuple[torch.Tensor]: + def get_inputs(self) -> input_t: return (torch.randn(1, 3, 256, 256),) def forward(self, x): @@ -126,33 +127,23 @@ def forward(self, x): return z, a -modules = [ - MergeOneOfTwoBN(True), - MergeOneOfTwoBN(False), - MergeTwosOfTwoBN(True), - MergeNoBN(True), -] - - -class TestFuseBatchnormPass(unittest.TestCase): - - @parameterized.expand(modules) - def test_fuse_batchnorm_tosa_MI(self, module): - """Test various cases where the batchnorm should and shouldn't be fused.""" - inputs = module.get_inputs() - test_pass_stage = RunPasses(passes_with_exported_program=[FuseBatchnorm2DPass]) - ( - ( - ArmTester( - module, - example_inputs=inputs, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"), - ) - .export() - .to_edge() - .check_count(module.ops_before_pass) - .run_passes(test_pass_stage) - .check_count(module.ops_after_pass) - .run_method_and_compare_outputs() - ) - ) +modules = { + "merge_one_of_two_bn_affine": MergeOneOfTwoBN(True), + "merge_one_of_two_bn": MergeOneOfTwoBN(False), + "merge_two_of_two_bn_affine": MergeTwosOfTwoBN(True), + "merge_no_bn_affine": MergeNoBN(True), +} + + +@common.parametrize("module", modules) +def test_fuse_batchnorm_tosa_MI(module): + """Test various cases where the batchnorm should and shouldn't be fused.""" + pipeline = TestPassPipeline[input_t]( + module, + module.get_inputs(), + tosa_version="TOSA-0.80+MI", + ops_before_pass=module.ops_before_pass, + ops_after_pass=module.ops_after_pass, + passes_with_exported_program=[FuseBatchnorm2DPass], + ) + pipeline.run() diff --git a/backends/arm/test/passes/test_insert_table_ops_pass.py b/backends/arm/test/passes/test_insert_table_ops_pass.py index c0a9235fa6e..5c761c8bcb4 100644 --- a/backends/arm/test/passes/test_insert_table_ops_pass.py +++ b/backends/arm/test/passes/test_insert_table_ops_pass.py @@ -1,20 +1,19 @@ # Copyright 2025 Arm Limited and/or its affiliates. -# All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import unittest + +from typing import Tuple import torch from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( FoldAndAnnotateQParamsPass, ) from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass +from executorch.backends.arm.test.tester.test_pipeline import TestPassPipeline -from executorch.backends.arm.test import common - -from executorch.backends.arm.test.tester.arm_tester import ArmTester, RunPasses +input_t = Tuple[torch.Tensor] # Input x class Sigmoid(torch.nn.Module): @@ -22,34 +21,26 @@ class Sigmoid(torch.nn.Module): def forward(self, x: torch.Tensor): return x.sigmoid() - def get_inputs(self): + def get_inputs(self) -> input_t: return (torch.rand(4),) -class TestInsertTablePass(unittest.TestCase): - - def test_insert_table_tosa_BI(self): - module = Sigmoid() - test_pass_stage = RunPasses( - [FoldAndAnnotateQParamsPass], - passes_with_exported_program=[InsertTableOpsPass], - ) - ( - ArmTester( - module, - example_inputs=module.get_inputs(), - compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"), - ) - .quantize() - .export() - .to_edge() - .run_passes(test_pass_stage) - .check("tosa._table") - .check_count( - { - "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 1, - "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 1, - } - ) - .check_not(["aten_sigmoid_default"]) - ) +def test_insert_table_tosa_BI(): + module = Sigmoid() + pipeline = TestPassPipeline[input_t]( + module, + module.get_inputs(), + tosa_version="TOSA-0.80+BI", + ops_before_pass={}, + ops_after_pass={ + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 1, + "tosa._table": 1, + }, + ops_not_after_pass=["aten_sigmoid_default"], + pass_list=[FoldAndAnnotateQParamsPass], + passes_with_exported_program=[InsertTableOpsPass], + ) + pipeline.pop_stage(-1) # Do not compare output + + pipeline.run() diff --git a/backends/arm/test/passes/test_ioquantization_pass.py b/backends/arm/test/passes/test_ioquantization_pass.py index e31007f1ed6..ecaff5e3673 100644 --- a/backends/arm/test/passes/test_ioquantization_pass.py +++ b/backends/arm/test/passes/test_ioquantization_pass.py @@ -1,10 +1,8 @@ # Copyright 2025 Arm Limited and/or its affiliates. -# All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import unittest import torch @@ -24,47 +22,43 @@ def get_inputs(self): return (a, b) -class TestIOQuantizationPass(unittest.TestCase): +def test_ioquantisation_pass_u55_BI(): """ Test the executorch/exir/passes/quanize_io_pass pass works(meaning we don't get Q/DQ nodes) on a simple model """ - - def test_ioquantisation_pass(self): - model = SimpleModel() - tester = ( - ArmTester( - model, - example_inputs=model.get_inputs(), - compile_spec=common.get_u55_compile_spec(), - ) - .quantize() - .export() - .to_edge() - .check_count( - { - "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 3 - } - ) - .check_count( - { - "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3 - } - ) - .partition() - .check_count( - { - "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2 - } - ) - .check_count( - { - "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 1 - } - ) + model = SimpleModel() + tester = ( + ArmTester( + model, + example_inputs=model.get_inputs(), + compile_spec=common.get_u55_compile_spec(), + ) + .quantize() + .export() + .to_edge() + .check_count( + { + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 3 + } + ) + .check_count( + { + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3 + } + ) + .partition() + .check_count( + { + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2 + } ) - edge = tester.get_artifact() - edge.transform( - passes=[QuantizeInputs(edge, [0, 1]), QuantizeOutputs(edge, [0])] + .check_count( + { + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 1 + } ) - tester.check_not(["edge__ops_quantized_decomposed_quantize_per_tensor"]) - tester.check_not(["edge__ops_quantized_decomposed_dequantize_per_tensor"]) + ) + edge = tester.get_artifact() + edge.transform(passes=[QuantizeInputs(edge, [0, 1]), QuantizeOutputs(edge, [0])]) + tester.check_not(["edge__ops_quantized_decomposed_quantize_per_tensor"]) + tester.check_not(["edge__ops_quantized_decomposed_dequantize_per_tensor"]) diff --git a/backends/arm/test/passes/test_meandim_to_averagepool2d.py b/backends/arm/test/passes/test_meandim_to_averagepool2d.py index e07e91ed727..935085c66e4 100644 --- a/backends/arm/test/passes/test_meandim_to_averagepool2d.py +++ b/backends/arm/test/passes/test_meandim_to_averagepool2d.py @@ -1,79 +1,78 @@ # Copyright 2024-2025 Arm Limited and/or its affiliates. -# All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import unittest + +from typing import Tuple import torch from executorch.backends.arm._passes.meandim_to_averagepool_pass import ( ConvertMeanDimToAveragePoolPass, ) - from executorch.backends.arm.test import common -from executorch.backends.arm.test.tester.arm_tester import ArmTester +from executorch.backends.arm.test.tester.test_pipeline import TestPassPipeline -from executorch.backends.xnnpack.test.tester.tester import RunPasses + +input_t = Tuple[torch.Tensor, torch.Tensor] # Input x class MeanDim(torch.nn.Module): def forward(self, x): return torch.mean(x, dim=[-1, -2], keepdim=True) - def get_inputs(self): + def get_inputs(self) -> input_t: return (torch.rand(1, 1280, 7, 7),) + ops_before_pass = {"executorch_exir_dialects_edge__ops_aten_mean_dim": 1} + ops_after_pass = {"executorch_exir_dialects_edge__ops_aten_avg_pool2d_default": 1} + ops_not_after_pass = [ + "aten_sum_dim_int_list", + "aten_full_default", + "aten_mul_tensor", + ] + class MeanDim2(torch.nn.Module): def forward(self, x): return torch.mean(x, dim=1) - def get_inputs(self): + def get_inputs(self) -> input_t: return (torch.rand(1, 1280, 7, 7),) + ops_before_pass = { + "aten_sum_dim_int_list": 3, + "aten_full_default": 4, + "aten_mul_tensor": 3, + } + ops_after_pass = { + "aten_sum_dim_int_list": 3, + "aten_full_default": 4, + "aten_mul_tensor": 3, + } + ops_not_after_pass = ["executorch_exir_dialects_edge__ops_aten_avg_pool2d_default"] -class TestMeandimToAveragePool2dPass(unittest.TestCase): + +modules = { + "meandim_to_averagepool": MeanDim(), + "meandim_no_modification": MeanDim2(), +} + + +@common.parametrize("module", modules) +def test_meandim_to_avgpool_tosa_BI(module): """ Tests the MeanDimToAveragePool2dPass which converts mean.dim to average_pool2d for the special case where dim is [-1, -2] and keepdim is True. """ - - def test_tosa_BI_meandim_to_averagepool(self): - module = MeanDim() - test_pass_stage = RunPasses([ConvertMeanDimToAveragePoolPass]) - ( - ArmTester( - module, - example_inputs=module.get_inputs(), - compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"), - ) - .quantize() - .export() - .to_edge() - .check(["executorch_exir_dialects_edge__ops_aten_mean_dim"]) - .run_passes(test_pass_stage) - .check(["executorch_exir_dialects_edge__ops_aten_avg_pool2d_default"]) - ) - - def test_tosa_BI_meandim_no_modification(self): - module = MeanDim2() - test_pass_stage = RunPasses([ConvertMeanDimToAveragePoolPass]) - ( - ArmTester( - module, - example_inputs=module.get_inputs(), - compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"), - ) - .quantize() - .export() - .to_edge() - .check(["aten_sum_dim_int_list"]) - .check(["aten_full_default"]) - .check(["aten_mul_tensor"]) - .run_passes(test_pass_stage) - .check(["aten_sum_dim_int_list"]) - .check(["aten_full_default"]) - .check(["aten_mul_tensor"]) - .check_not(["executorch_exir_dialects_edge__ops_aten_avg_pool2d_default"]) - ) + pipeline = TestPassPipeline[input_t]( + module, + module.get_inputs(), + tosa_version="TOSA-0.80+BI", + ops_before_pass=module.ops_before_pass, + ops_after_pass=module.ops_after_pass, + ops_not_after_pass=module.ops_not_after_pass, + pass_list=[ConvertMeanDimToAveragePoolPass], + ) + pipeline.pop_stage(-1) # Do not compare output + pipeline.run() diff --git a/backends/arm/test/passes/test_unsqueeze_before_repeat_pass.py b/backends/arm/test/passes/test_unsqueeze_before_repeat_pass.py index 4323e7332db..8f4a9130cea 100644 --- a/backends/arm/test/passes/test_unsqueeze_before_repeat_pass.py +++ b/backends/arm/test/passes/test_unsqueeze_before_repeat_pass.py @@ -1,17 +1,20 @@ -# Copyright 2024 Arm Limited and/or its affiliates. -# All rights reserved. +# Copyright 2024-2025 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import unittest + +from typing import Dict, Tuple import torch from executorch.backends.arm._passes.unsqueeze_before_repeat_pass import ( UnsqueezeBeforeRepeatPass, ) from executorch.backends.arm.test import common -from executorch.backends.arm.test.tester.arm_tester import ArmTester -from executorch.backends.xnnpack.test.tester.tester import RunPasses +from executorch.backends.arm.test.tester.test_pipeline import TestPassPipeline + +input_t = Tuple[ + torch.Tensor, Dict[str, int], list[str] +] # Input x, ops_after_pass, ops_not_after_pass class Repeat(torch.nn.Module): @@ -22,53 +25,37 @@ class Repeat(torch.nn.Module): def forward(self, x: torch.Tensor): return x.repeat(2, 2, 2, 2) + test_data: Dict[str, input_t] = { + "insert_view": ( + (torch.rand((2, 3, 4)),), + {"aten_repeat_default": 3, "aten_view_copy_default": 4}, + [], + ), + "dont_insert_view": ( + (torch.rand((2, 3, 4, 1)),), + {"aten_repeat_default": 3}, + ["aten_view_copy_default"], + ), + } -class TestUnsqueezeBeforeRepeatPass(unittest.TestCase): - def test_tosa_MI_insert_view(self): - """ - When rank(input) != number of repeated dimensions (=4 in Repeat module), - insert view. - """ - module = Repeat() - inputs = (torch.rand((2, 3, 4)),) - test_pass_stage = RunPasses([UnsqueezeBeforeRepeatPass]) - ( - ( - ArmTester( - module, - example_inputs=inputs, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"), - ) - .export() - .to_edge() - .check(["aten_repeat_default"]) - .check_not(["aten_view_copy_default"]) - .run_passes(test_pass_stage) - .check(["aten_repeat_default", "aten_view_copy_default"]) - ) - ) - def test_tosa_MI_dont_insert_view(self): - """ - When rank(input) == number of repeated dimensions (=4 in Repeat module), - DON'T insert view. - """ - module = Repeat() - inputs = (torch.rand((2, 3, 4, 1)),) - test_pass_stage = RunPasses([UnsqueezeBeforeRepeatPass]) - ( - ( - ArmTester( - module, - example_inputs=inputs, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"), - ) - .export() - .to_edge() - .check(["aten_repeat_default"]) - .check_not(["aten_view_copy_default"]) - .run_passes(test_pass_stage) - .check(["aten_repeat_default"]) - .check_not(["aten_view_copy_default"]) - ) - ) +@common.parametrize("test_data", Repeat.test_data) +def test_unsqueeze_before_repeat_tosa_MI(test_data): + """ + When rank(input) != number of repeated dimensions (=4 in Repeat module), + insert view. + """ + module = Repeat() + data, ops_after_pass, ops_not_after_pass = test_data + pipeline = TestPassPipeline( + module, + data, + tosa_version="TOSA-0.80+MI", + ops_before_pass={"aten_repeat_default": 3}, + ops_not_before_pass=["aten_view_copy_default"], + ops_after_pass=ops_after_pass, + ops_not_after_pass=ops_not_after_pass, + pass_list=[UnsqueezeBeforeRepeatPass], + ) + pipeline.pop_stage(-1) # Do not compare output + pipeline.run() diff --git a/backends/arm/test/tester/test_pipeline.py b/backends/arm/test/tester/test_pipeline.py index b25773d604c..0f079b3a6fd 100644 --- a/backends/arm/test/tester/test_pipeline.py +++ b/backends/arm/test/tester/test_pipeline.py @@ -4,13 +4,14 @@ # LICENSE file in the root directory of this source tree. import logging -from typing import Any, Callable, Generic, List, TypeVar +from typing import Any, Callable, Dict, Generic, List, Optional, Type, TypeVar import torch from executorch.backends.arm.test import common -from executorch.backends.arm.test.tester.arm_tester import ArmTester +from executorch.backends.arm.test.tester.arm_tester import ArmTester, RunPasses from executorch.exir.backend.compile_spec_schema import CompileSpec - +from executorch.exir.pass_base import ExportPass +from torch._export.pass_base import PassType logger = logging.getLogger(__name__) T = TypeVar("T") @@ -430,3 +431,79 @@ def __init__( qtol=1, inputs=self.test_data, ) + + +class TestPassPipeline(BasePipelineMaker, Generic[T]): + """ + Runs single passes directly on an edge_program and checks operators before/after. + + Attributes: + module: The module which the pipeline is applied to. + test_data: Data used for quantizing and testing the module. + tosa_version: The TOSA-version which to test for. + + ops_before_pass : Ops expected to be found in the graph before passes. + ops_not_before_pass : Ops expected not to be found in the graph before passes. + ops_after_pass : Ops expected to be found in the graph after passes. + ops_notafter_pass : Ops expected not to be found in the graph after passes. + + pass_list: List of regular passes. + pass_functions: List of functions applied directly to the exported program. + passes_with_exported_program: List of passes initiated with an exported_program. + + Passes are run in order pass_list -> pass_functions -> passes_with_exported_program. + See arm_tester.RunPasses() for more information. + """ + + def __init__( + self, + module: torch.nn.Module, + test_data: T, + tosa_version: str, + ops_before_pass: Optional[Dict[str, int]] = None, + ops_not_before_pass: Optional[list[str]] = None, + ops_after_pass: Optional[Dict[str, int]] = None, + ops_not_after_pass: Optional[list[str]] = None, + pass_list: Optional[List[Type[PassType]]] = None, + pass_functions: Optional[List[Callable]] = None, + passes_with_exported_program: Optional[List[Type[ExportPass]]] = None, + ): + compile_spec = common.get_tosa_compile_spec( + tosa_version, + ) + super().__init__( + module, + test_data, + None, + None, + compile_spec, + use_to_edge_transform_and_lower=False, + ) + + # Delete most of the pipeline + self.pop_stage("check.exir") + self.pop_stage("partition") + self.pop_stage("check_not.exir") + self.pop_stage("check_count.exir") + self.pop_stage("to_executorch") + self.pop_stage("check.aten") + + if "BI" in tosa_version: + self.add_stage(self.tester.quantize, pos=0) + + # Add checks/check_not's if given + if ops_before_pass: + self.add_stage(self.tester.check_count, ops_before_pass, suffix="before") + if ops_not_before_pass: + self.add_stage(self.tester.check_not, ops_not_before_pass, suffix="before") + test_pass_stage = RunPasses( + pass_list, pass_functions, passes_with_exported_program + ) + + self.add_stage(self.tester.run_passes, test_pass_stage) + + if ops_after_pass: + self.add_stage(self.tester.check_count, ops_after_pass, suffix="after") + if ops_not_after_pass: + self.add_stage(self.tester.check_not, ops_not_after_pass, suffix="after") + self.add_stage(self.tester.run_method_and_compare_outputs)