Skip to content

Commit 88cca2c

Browse files
mansnilszingo
andauthored
Arm backend: test without output re-order workaround (#15826)
By default outputs are re-ordered to correct order during TOSA lowering. However this is seen as a workaround as it should not be needed. Furthermore the output issue is not easily reproduced, rather it seems to happen randomly. Therefore we add a test case without the workaround, which is currently passing. In case it won't pass without the workaround at some point, the new changes might give some hints on why the workaround is needed and how to fix it. In case it continues to pass, we may switch the default and potentially even remove the workaround. cc @freddan80 @per @zingo @oscarandersson8218 @digantdesai --------- Co-authored-by: Zingo Andersen <[email protected]>
1 parent b91987e commit 88cca2c

File tree

3 files changed

+43
-6
lines changed

3 files changed

+43
-6
lines changed

backends/arm/common/arm_compile_spec.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,19 +35,22 @@ class DebugMode(Enum):
3535
_OUTPUT_FORMAT_KEY = "output_format"
3636
_DEBUG_ARTIFACT_KEY = "debug_artifact_path"
3737
_DEBUG_MODE_KEY = "dump_debug_info"
38+
_OUTPUT_REORDER_KEY = "ouput_reorder_workaround"
3839

3940
def _set_compile_specs(
4041
self,
4142
tosa_spec: TosaSpecification,
4243
compiler_flags: list[str],
4344
path_for_intermediates: str | None = None,
4445
tosa_debug_mode: DebugMode | None = None,
46+
output_order_workaround: bool = True,
4547
):
4648
"""Set all values of dataclass directly."""
4749
self.tosa_spec = tosa_spec
4850
self.compiler_flags = compiler_flags
4951
self.path_for_intermediates = path_for_intermediates
5052
self.tosa_debug_mode = tosa_debug_mode
53+
self.output_order_workaround = output_order_workaround
5154

5255
@classmethod
5356
def from_list(cls, compile_specs: list[CompileSpec]): # noqa: C901
@@ -56,10 +59,15 @@ def from_list(cls, compile_specs: list[CompileSpec]): # noqa: C901
5659
compiler_flags: list[str] | None = None
5760
path_for_intermediates: str | None = None
5861
tosa_debug_mode: ArmCompileSpec.DebugMode | None = None
62+
output_order_workaround: bool = True
5963
unknown_specs: dict[str, str] = {}
6064
for spec in compile_specs:
6165
key = spec.key
62-
val = spec.value.decode()
66+
val = (
67+
spec.value.decode()
68+
if isinstance(spec.value, (bytes, bytearray))
69+
else spec.value
70+
)
6371
if key == ArmCompileSpec._TOSA_SPEC_KEY:
6472
if tosa_spec is not None:
6573
raise ValueError("More than one tosa_spec entry in compile spec.")
@@ -88,6 +96,8 @@ def from_list(cls, compile_specs: list[CompileSpec]): # noqa: C901
8896
"More than one tosa_debug_mode entry in compile spec."
8997
)
9098
tosa_debug_mode = ArmCompileSpec.DebugMode[val]
99+
elif key == ArmCompileSpec._OUTPUT_REORDER_KEY:
100+
output_order_workaround = val # type: ignore[assignment]
91101
else:
92102
unknown_specs[key] = val
93103

@@ -109,6 +119,7 @@ def from_list(cls, compile_specs: list[CompileSpec]): # noqa: C901
109119
compiler_flags=compiler_flags,
110120
path_for_intermediates=path_for_intermediates,
111121
tosa_debug_mode=tosa_debug_mode,
122+
output_order_workaround=output_order_workaround,
112123
)
113124
cls.from_list_hook(compile_spec, unknown_specs)
114125
compile_spec.validate()
@@ -170,6 +181,14 @@ def to_list(self):
170181
)
171182
)
172183

184+
if not self.output_order_workaround:
185+
compile_spec.append(
186+
CompileSpec(
187+
ArmCompileSpec._OUTPUT_REORDER_KEY,
188+
self.output_order_workaround,
189+
)
190+
)
191+
173192
return compile_spec
174193

175194
def get_intermediate_path(self) -> str | None:
@@ -201,6 +220,13 @@ def dump_debug_info(self, debug_mode: DebugMode | None):
201220
self.tosa_debug_mode = debug_mode
202221
return self
203222

223+
def set_output_order_workaround(self, output_order_workaround: bool):
224+
self.output_order_workaround = output_order_workaround
225+
return self
226+
227+
def get_output_order_workaround(self) -> bool:
228+
return self.output_order_workaround
229+
204230
@classmethod
205231
@abstractmethod
206232
def get_output_format(cls) -> str:

backends/arm/test/misc/test_outputs_order.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,14 +78,18 @@ def _read_tosa_outputs(tosa_path: Path):
7878
return shapes
7979

8080

81+
# TODO: MLETORCH-1266 Investigate output order issue
8182
@pytest.mark.parametrize("batch_size", [1, 4])
82-
def test_network_output_order_and_restore(batch_size):
83+
@pytest.mark.parametrize("output_order_workaround", [True, False])
84+
def test_network_output_order_and_restore(batch_size, output_order_workaround):
8385
model = Network(batch_norm=True).eval()
8486
# Prepare spec
8587
spec = TosaSpecification.create_from_string("TOSA-1.0+INT")
86-
compile_spec = TosaCompileSpec(tosa_spec=spec)
88+
tosa_compile_spec = TosaCompileSpec(spec).set_output_order_workaround(
89+
output_order_workaround
90+
)
8791
# Setup quantizer
88-
quantizer = TOSAQuantizer(compile_spec)
92+
quantizer = TOSAQuantizer(tosa_compile_spec)
8993
quantizer.set_global(
9094
get_symmetric_quantization_config(is_qat=True, is_per_channel=False)
9195
)
@@ -100,7 +104,7 @@ def test_network_output_order_and_restore(batch_size):
100104
with tempfile.TemporaryDirectory(dir="") as tmpdir:
101105
art_dir = Path(tmpdir)
102106
part = TOSAPartitioner(
103-
TosaCompileSpec(spec).dump_intermediate_artifacts_to(str(art_dir))
107+
tosa_compile_spec.dump_intermediate_artifacts_to(str(art_dir))
104108
)
105109
_ = to_edge_transform_and_lower(aten_gm, partitioner=[part])
106110
# Expect exactly one .tosa file in the artefact dir

backends/arm/tosa/backend.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,7 @@ def _preprocess_module( # noqa: C901
283283
output_node.update_arg(0, [output_node.args[0]])
284284
node_to_id_map = _annotate_external_ids(graph_module.graph)
285285
artifact_path = compile_spec.get_intermediate_path()
286+
output_order_workaround = compile_spec.get_output_order_workaround()
286287

287288
# TODO: Fix the need to lazily import this.
288289
from executorch.backends.arm._passes import ArmPassManager
@@ -295,7 +296,12 @@ def _preprocess_module( # noqa: C901
295296
from executorch.backends.arm.operators.node_visitor import get_node_visitors
296297

297298
node_visitors = get_node_visitors(edge_program, tosa_spec, debug_hook)
298-
graph_module = _sort_outputs(graph_module, node_to_id_map)
299+
300+
if output_order_workaround:
301+
logger.debug("Re-sorting outputs during TOSA lowering.")
302+
graph_module = _sort_outputs(graph_module, node_to_id_map)
303+
else:
304+
logger.debug("No re-sorting outputs (workaround) during TOSA lowering.")
299305

300306
if submodule_name is not None:
301307
tosa_graph.startRegion(submodule_name)
@@ -375,4 +381,5 @@ def filter_tosa_compile_specs(
375381
TosaCompileSpec(compile_spec.tosa_spec)
376382
.dump_intermediate_artifacts_to(compile_spec.get_intermediate_path())
377383
.dump_debug_info(compile_spec.tosa_debug_mode)
384+
.set_output_order_workaround(compile_spec.output_order_workaround)
378385
)

0 commit comments

Comments
 (0)