Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion backends/arm/common/arm_compile_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,19 +35,22 @@ class DebugMode(Enum):
_OUTPUT_FORMAT_KEY = "output_format"
_DEBUG_ARTIFACT_KEY = "debug_artifact_path"
_DEBUG_MODE_KEY = "dump_debug_info"
_OUTPUT_REORDER_KEY = "ouput_reorder_workaround"

def _set_compile_specs(
self,
tosa_spec: TosaSpecification,
compiler_flags: list[str],
path_for_intermediates: str | None = None,
tosa_debug_mode: DebugMode | None = None,
output_order_workaround: bool = True,
):
"""Set all values of dataclass directly."""
self.tosa_spec = tosa_spec
self.compiler_flags = compiler_flags
self.path_for_intermediates = path_for_intermediates
self.tosa_debug_mode = tosa_debug_mode
self.output_order_workaround = output_order_workaround

@classmethod
def from_list(cls, compile_specs: list[CompileSpec]): # noqa: C901
Expand All @@ -56,10 +59,15 @@ def from_list(cls, compile_specs: list[CompileSpec]): # noqa: C901
compiler_flags: list[str] | None = None
path_for_intermediates: str | None = None
tosa_debug_mode: ArmCompileSpec.DebugMode | None = None
output_order_workaround: bool = True
unknown_specs: dict[str, str] = {}
for spec in compile_specs:
key = spec.key
val = spec.value.decode()
val = (
spec.value.decode()
if isinstance(spec.value, (bytes, bytearray))
else spec.value
)
if key == ArmCompileSpec._TOSA_SPEC_KEY:
if tosa_spec is not None:
raise ValueError("More than one tosa_spec entry in compile spec.")
Expand Down Expand Up @@ -88,6 +96,8 @@ def from_list(cls, compile_specs: list[CompileSpec]): # noqa: C901
"More than one tosa_debug_mode entry in compile spec."
)
tosa_debug_mode = ArmCompileSpec.DebugMode[val]
elif key == ArmCompileSpec._OUTPUT_REORDER_KEY:
output_order_workaround = val # type: ignore[assignment]
else:
unknown_specs[key] = val

Expand All @@ -109,6 +119,7 @@ def from_list(cls, compile_specs: list[CompileSpec]): # noqa: C901
compiler_flags=compiler_flags,
path_for_intermediates=path_for_intermediates,
tosa_debug_mode=tosa_debug_mode,
output_order_workaround=output_order_workaround,
)
cls.from_list_hook(compile_spec, unknown_specs)
compile_spec.validate()
Expand Down Expand Up @@ -170,6 +181,14 @@ def to_list(self):
)
)

if not self.output_order_workaround:
compile_spec.append(
CompileSpec(
ArmCompileSpec._OUTPUT_REORDER_KEY,
self.output_order_workaround,
)
)

return compile_spec

def get_intermediate_path(self) -> str | None:
Expand Down Expand Up @@ -201,6 +220,13 @@ def dump_debug_info(self, debug_mode: DebugMode | None):
self.tosa_debug_mode = debug_mode
return self

def set_output_order_workaround(self, output_order_workaround: bool):
self.output_order_workaround = output_order_workaround
return self

def get_output_order_workaround(self) -> bool:
return self.output_order_workaround

@classmethod
@abstractmethod
def get_output_format(cls) -> str:
Expand Down
12 changes: 8 additions & 4 deletions backends/arm/test/misc/test_outputs_order.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,14 +78,18 @@ def _read_tosa_outputs(tosa_path: Path):
return shapes


# TODO: MLETORCH-1266 Investigate output order issue
@pytest.mark.parametrize("batch_size", [1, 4])
def test_network_output_order_and_restore(batch_size):
@pytest.mark.parametrize("output_order_workaround", [True, False])
def test_network_output_order_and_restore(batch_size, output_order_workaround):
model = Network(batch_norm=True).eval()
# Prepare spec
spec = TosaSpecification.create_from_string("TOSA-1.0+INT")
compile_spec = TosaCompileSpec(tosa_spec=spec)
tosa_compile_spec = TosaCompileSpec(spec).set_output_order_workaround(
output_order_workaround
)
# Setup quantizer
quantizer = TOSAQuantizer(compile_spec)
quantizer = TOSAQuantizer(tosa_compile_spec)
quantizer.set_global(
get_symmetric_quantization_config(is_qat=True, is_per_channel=False)
)
Expand All @@ -100,7 +104,7 @@ def test_network_output_order_and_restore(batch_size):
with tempfile.TemporaryDirectory(dir="") as tmpdir:
art_dir = Path(tmpdir)
part = TOSAPartitioner(
TosaCompileSpec(spec).dump_intermediate_artifacts_to(str(art_dir))
tosa_compile_spec.dump_intermediate_artifacts_to(str(art_dir))
)
_ = to_edge_transform_and_lower(aten_gm, partitioner=[part])
# Expect exactly one .tosa file in the artefact dir
Expand Down
9 changes: 8 additions & 1 deletion backends/arm/tosa/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,7 @@ def _preprocess_module( # noqa: C901
output_node.update_arg(0, [output_node.args[0]])
node_to_id_map = _annotate_external_ids(graph_module.graph)
artifact_path = compile_spec.get_intermediate_path()
output_order_workaround = compile_spec.get_output_order_workaround()

# TODO: Fix the need to lazily import this.
from executorch.backends.arm._passes import ArmPassManager
Expand All @@ -295,7 +296,12 @@ def _preprocess_module( # noqa: C901
from executorch.backends.arm.operators.node_visitor import get_node_visitors

node_visitors = get_node_visitors(edge_program, tosa_spec, debug_hook)
graph_module = _sort_outputs(graph_module, node_to_id_map)

if output_order_workaround:
logger.debug("Re-sorting outputs during TOSA lowering.")
graph_module = _sort_outputs(graph_module, node_to_id_map)
else:
logger.debug("No re-sorting outputs (workaround) during TOSA lowering.")

if submodule_name is not None:
tosa_graph.startRegion(submodule_name)
Expand Down Expand Up @@ -375,4 +381,5 @@ def filter_tosa_compile_specs(
TosaCompileSpec(compile_spec.tosa_spec)
.dump_intermediate_artifacts_to(compile_spec.get_intermediate_path())
.dump_debug_info(compile_spec.tosa_debug_mode)
.set_output_order_workaround(compile_spec.output_order_workaround)
)
Loading