diff --git a/backends/arm/common/arm_compile_spec.py b/backends/arm/common/arm_compile_spec.py index a33531e2411..1075594e901 100644 --- a/backends/arm/common/arm_compile_spec.py +++ b/backends/arm/common/arm_compile_spec.py @@ -35,6 +35,7 @@ class DebugMode(Enum): _OUTPUT_FORMAT_KEY = "output_format" _DEBUG_ARTIFACT_KEY = "debug_artifact_path" _DEBUG_MODE_KEY = "dump_debug_info" + _OUTPUT_REORDER_KEY = "ouput_reorder_workaround" def _set_compile_specs( self, @@ -42,12 +43,14 @@ def _set_compile_specs( compiler_flags: list[str], path_for_intermediates: str | None = None, tosa_debug_mode: DebugMode | None = None, + output_order_workaround: bool = True, ): """Set all values of dataclass directly.""" self.tosa_spec = tosa_spec self.compiler_flags = compiler_flags self.path_for_intermediates = path_for_intermediates self.tosa_debug_mode = tosa_debug_mode + self.output_order_workaround = output_order_workaround @classmethod def from_list(cls, compile_specs: list[CompileSpec]): # noqa: C901 @@ -56,10 +59,15 @@ def from_list(cls, compile_specs: list[CompileSpec]): # noqa: C901 compiler_flags: list[str] | None = None path_for_intermediates: str | None = None tosa_debug_mode: ArmCompileSpec.DebugMode | None = None + output_order_workaround: bool = True unknown_specs: dict[str, str] = {} for spec in compile_specs: key = spec.key - val = spec.value.decode() + val = ( + spec.value.decode() + if isinstance(spec.value, (bytes, bytearray)) + else spec.value + ) if key == ArmCompileSpec._TOSA_SPEC_KEY: if tosa_spec is not None: raise ValueError("More than one tosa_spec entry in compile spec.") @@ -88,6 +96,8 @@ def from_list(cls, compile_specs: list[CompileSpec]): # noqa: C901 "More than one tosa_debug_mode entry in compile spec." ) tosa_debug_mode = ArmCompileSpec.DebugMode[val] + elif key == ArmCompileSpec._OUTPUT_REORDER_KEY: + output_order_workaround = val # type: ignore[assignment] else: unknown_specs[key] = val @@ -109,6 +119,7 @@ def from_list(cls, compile_specs: list[CompileSpec]): # noqa: C901 compiler_flags=compiler_flags, path_for_intermediates=path_for_intermediates, tosa_debug_mode=tosa_debug_mode, + output_order_workaround=output_order_workaround, ) cls.from_list_hook(compile_spec, unknown_specs) compile_spec.validate() @@ -170,6 +181,14 @@ def to_list(self): ) ) + if not self.output_order_workaround: + compile_spec.append( + CompileSpec( + ArmCompileSpec._OUTPUT_REORDER_KEY, + self.output_order_workaround, + ) + ) + return compile_spec def get_intermediate_path(self) -> str | None: @@ -201,6 +220,13 @@ def dump_debug_info(self, debug_mode: DebugMode | None): self.tosa_debug_mode = debug_mode return self + def set_output_order_workaround(self, output_order_workaround: bool): + self.output_order_workaround = output_order_workaround + return self + + def get_output_order_workaround(self) -> bool: + return self.output_order_workaround + @classmethod @abstractmethod def get_output_format(cls) -> str: diff --git a/backends/arm/test/misc/test_outputs_order.py b/backends/arm/test/misc/test_outputs_order.py index cada9e89922..253888537f8 100644 --- a/backends/arm/test/misc/test_outputs_order.py +++ b/backends/arm/test/misc/test_outputs_order.py @@ -78,14 +78,18 @@ def _read_tosa_outputs(tosa_path: Path): return shapes +# TODO: MLETORCH-1266 Investigate output order issue @pytest.mark.parametrize("batch_size", [1, 4]) -def test_network_output_order_and_restore(batch_size): +@pytest.mark.parametrize("output_order_workaround", [True, False]) +def test_network_output_order_and_restore(batch_size, output_order_workaround): model = Network(batch_norm=True).eval() # Prepare spec spec = TosaSpecification.create_from_string("TOSA-1.0+INT") - compile_spec = TosaCompileSpec(tosa_spec=spec) + tosa_compile_spec = TosaCompileSpec(spec).set_output_order_workaround( + output_order_workaround + ) # Setup quantizer - quantizer = TOSAQuantizer(compile_spec) + quantizer = TOSAQuantizer(tosa_compile_spec) quantizer.set_global( get_symmetric_quantization_config(is_qat=True, is_per_channel=False) ) @@ -100,7 +104,7 @@ def test_network_output_order_and_restore(batch_size): with tempfile.TemporaryDirectory(dir="") as tmpdir: art_dir = Path(tmpdir) part = TOSAPartitioner( - TosaCompileSpec(spec).dump_intermediate_artifacts_to(str(art_dir)) + tosa_compile_spec.dump_intermediate_artifacts_to(str(art_dir)) ) _ = to_edge_transform_and_lower(aten_gm, partitioner=[part]) # Expect exactly one .tosa file in the artefact dir diff --git a/backends/arm/tosa/backend.py b/backends/arm/tosa/backend.py index e643a95eecc..99fcadac081 100644 --- a/backends/arm/tosa/backend.py +++ b/backends/arm/tosa/backend.py @@ -283,6 +283,7 @@ def _preprocess_module( # noqa: C901 output_node.update_arg(0, [output_node.args[0]]) node_to_id_map = _annotate_external_ids(graph_module.graph) artifact_path = compile_spec.get_intermediate_path() + output_order_workaround = compile_spec.get_output_order_workaround() # TODO: Fix the need to lazily import this. from executorch.backends.arm._passes import ArmPassManager @@ -295,7 +296,12 @@ def _preprocess_module( # noqa: C901 from executorch.backends.arm.operators.node_visitor import get_node_visitors node_visitors = get_node_visitors(edge_program, tosa_spec, debug_hook) - graph_module = _sort_outputs(graph_module, node_to_id_map) + + if output_order_workaround: + logger.debug("Re-sorting outputs during TOSA lowering.") + graph_module = _sort_outputs(graph_module, node_to_id_map) + else: + logger.debug("No re-sorting outputs (workaround) during TOSA lowering.") if submodule_name is not None: tosa_graph.startRegion(submodule_name) @@ -375,4 +381,5 @@ def filter_tosa_compile_specs( TosaCompileSpec(compile_spec.tosa_spec) .dump_intermediate_artifacts_to(compile_spec.get_intermediate_path()) .dump_debug_info(compile_spec.tosa_debug_mode) + .set_output_order_workaround(compile_spec.output_order_workaround) )