diff --git a/backends/arm/test/common.py b/backends/arm/test/common.py index 3b1e6ce86b5..aae4f3e6967 100644 --- a/backends/arm/test/common.py +++ b/backends/arm/test/common.py @@ -310,3 +310,14 @@ def _clean_dir(dir: Path, filter: str, num_save=10): for remove in sorted_files[0 : len(sorted_files) - num_save]: file = remove[1] file.unlink() + + +def get_target_board(compile_spec: list[CompileSpec]) -> str | None: + for spec in compile_spec: + if spec.key == "compile_flags": + flags = spec.value.decode() + if "u55" in flags: + return "corstone-300" + elif "u85" in flags: + return "corstone-320" + return None diff --git a/backends/arm/test/ops/test_add.py b/backends/arm/test/ops/test_add.py index 66e278ee0f1..6676a38addb 100644 --- a/backends/arm/test/ops/test_add.py +++ b/backends/arm/test/ops/test_add.py @@ -115,6 +115,8 @@ def _test_add_ethos_BI_pipeline( .to_executorch() .serialize() ) + if common.is_option_enabled("corstone300"): + tester.run_method_and_compare_outputs(qtol=1, inputs=test_data) return tester @@ -131,28 +133,20 @@ def test_add_tosa_BI(self, test_data: torch.Tensor): @parameterized.expand(Add.test_parameters) def test_add_u55_BI(self, test_data: torch.Tensor): test_data = (test_data,) - tester = self._test_add_ethos_BI_pipeline( + self._test_add_ethos_BI_pipeline( self.Add(), common.get_u55_compile_spec(permute_memory_to_nhwc=True), test_data, ) - if common.is_option_enabled("corstone300"): - tester.run_method_and_compare_outputs( - qtol=1, inputs=test_data, target_board="corstone-300" - ) @parameterized.expand(Add.test_parameters) def test_add_u85_BI(self, test_data: torch.Tensor): test_data = (test_data,) - tester = self._test_add_ethos_BI_pipeline( + self._test_add_ethos_BI_pipeline( self.Add(), common.get_u85_compile_spec(permute_memory_to_nhwc=True), test_data, ) - if common.is_option_enabled("corstone300"): - tester.run_method_and_compare_outputs( - qtol=1, inputs=test_data, target_board="corstone-320" - ) @parameterized.expand(Add2.test_parameters) def test_add2_tosa_MI(self, operand1: torch.Tensor, operand2: torch.Tensor): @@ -167,21 +161,13 @@ def test_add2_tosa_BI(self, operand1: torch.Tensor, operand2: torch.Tensor): @parameterized.expand(Add2.test_parameters) def test_add2_u55_BI(self, operand1: torch.Tensor, operand2: torch.Tensor): test_data = (operand1, operand2) - tester = self._test_add_ethos_BI_pipeline( + self._test_add_ethos_BI_pipeline( self.Add2(), common.get_u55_compile_spec(), test_data ) - if common.is_option_enabled("corstone300"): - tester.run_method_and_compare_outputs( - qtol=1, inputs=test_data, target_board="corstone-300" - ) @parameterized.expand(Add2.test_parameters) def test_add2_u85_BI(self, operand1: torch.Tensor, operand2: torch.Tensor): test_data = (operand1, operand2) - tester = self._test_add_ethos_BI_pipeline( + self._test_add_ethos_BI_pipeline( self.Add2(), common.get_u85_compile_spec(), test_data ) - if common.is_option_enabled("corstone300"): - tester.run_method_and_compare_outputs( - qtol=1, inputs=test_data, target_board="corstone-320" - ) diff --git a/backends/arm/test/runner_utils.py b/backends/arm/test/runner_utils.py index 608761098e0..5940067af62 100644 --- a/backends/arm/test/runner_utils.py +++ b/backends/arm/test/runner_utils.py @@ -191,9 +191,6 @@ def init_run( target_board: str, ): - if target_board not in ["corstone-300", "corstone-320"]: - raise RuntimeError(f"Unknown target board: {target_board}") - self.input_names = _get_input_names(edge_program) self.output_node = _get_output_node(exported_program) self.output_name = self.output_node.name @@ -222,6 +219,8 @@ def run_corstone( assert ( self._has_init_run ), "RunnerUtil needs to be initialized using init_run() before running Corstone300." + if self.target_board not in ["corstone-300", "corstone-320"]: + raise RuntimeError(f"Unknown target board: {self.target_board}") pte_path = os.path.join(self.intermediate_path, "program.pte") assert os.path.exists(pte_path), f"Pte path '{pte_path}' not found." diff --git a/backends/arm/test/tester/arm_tester.py b/backends/arm/test/tester/arm_tester.py index e2062f24287..3564a3325a6 100644 --- a/backends/arm/test/tester/arm_tester.py +++ b/backends/arm/test/tester/arm_tester.py @@ -26,6 +26,7 @@ arm_test_options, current_time_formated, get_option, + get_target_board, ) from executorch.backends.arm.test.runner_utils import ( @@ -267,7 +268,7 @@ def run_method_and_compare_outputs( self, inputs: Optional[Tuple[torch.Tensor]] = None, stage: Optional[str] = None, - target_board: Optional[str] = "corstone-300", + target_board: Optional[str] = None, num_runs=1, atol=1e-03, rtol=1e-03, @@ -301,6 +302,9 @@ def run_method_and_compare_outputs( test_stage = self.stages[stage] is_quantized = self.stages[self.stage_name(tester.Quantize)] is not None + if target_board is None: + target_board = get_target_board(self.compile_spec) + exported_program = self.stages[self.stage_name(tester.Export)].artifact edge_program = edge_stage.artifact.exported_program() self.runner_util.init_run(