Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions backends/arm/arm_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,17 @@ def get_intermediate_path(compile_spec: List[CompileSpec]) -> str:
return None


def _get_first_delegation_tag(graph_module) -> str | None:
"""Get the first delegation tag from the graph_module or return None."""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you say that tag contains partition_id thus is unique

for node in graph_module.graph.nodes:
tag = node.meta.get("delegation_tag")
if tag:
return tag

logger.debug("No delegation tag found in partition.")
return None


@final
class ArmBackend(BackendDetails):
@staticmethod
Expand Down Expand Up @@ -220,8 +231,13 @@ def preprocess( # noqa: C901
# TODO: It would be awesome if this dump could somehow be done on top level and not here.
# Problem is that the desc.json has to be created on the tosa_graph object, which we can't
# access from top level.
if artifact_path is not None:
dbg_tosa_dump(tosa_graph, artifact_path)
if artifact_path:
tag = _get_first_delegation_tag(graph_module)
dbg_tosa_dump(
tosa_graph,
artifact_path,
suffix="{}".format(f"_{tag}" if tag else ""),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suffix can be _None, FYI

)

# Serialize and return the program. While we have always produced TOSA
# output as an intermediate, some flows compile to device binaries in
Expand Down
14 changes: 13 additions & 1 deletion backends/arm/test/runner_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import subprocess
import tempfile

from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
Expand Down Expand Up @@ -325,7 +326,18 @@ def run_tosa_ref_model(
self._has_init_run
), "RunnerUtil needs to be initialized using init_run() before running tosa reference."

desc_file_path = os.path.join(self.intermediate_path, "desc.json")
all_desc_file_paths = [
str(path) for path in Path(self.intermediate_path).glob("desc*.json")
]
assert (
all_desc_file_paths
), f"No TOSA description file found in '{self.intermediate_path}'."
if len(all_desc_file_paths) != 1:
raise NotImplementedError(
"Graphs with more than one partition are currently not supported."
)

desc_file_path = all_desc_file_paths[0]
assert os.path.exists(
desc_file_path
), f"desc_file_path: {desc_file_path} does not exist"
Expand Down
48 changes: 35 additions & 13 deletions backends/arm/test/tester/arm_tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from executorch.backends.xnnpack.test.tester import Tester
from executorch.exir import EdgeCompileConfig
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.lowered_backend_module import LoweredBackendModule
from torch.fx import Graph

logger = logging.getLogger(__name__)
Expand All @@ -44,21 +45,42 @@ class Partition(tester.Partition):
def dump_artifact(self, path_to_dump: Optional[str]):
super().dump_artifact(path_to_dump)

to_print = None
for spec in self.graph_module.lowered_module_0.compile_specs:
if spec.key == "output_format":
if spec.value == b"tosa":
tosa_fb = self.graph_module.lowered_module_0.processed_bytes
def get_output_format(lowered_module) -> str | None:
for spec in lowered_module.compile_specs:
if spec.key == "output_format":
return spec.value.decode()
return None

output = ""
for node in self.graph_module.graph.nodes:
if node.op == "get_attr" and node.name.startswith("lowered_module_"):
lowered_module = getattr(self.graph_module, node.name)
assert isinstance(
lowered_module, LoweredBackendModule
), f"Attribute {node.name} must be of type LoweredBackendModule."

output_format = get_output_format(lowered_module)
if output_format == "tosa":
tosa_fb = lowered_module.processed_bytes
to_print = dbg_tosa_fb_to_json(tosa_fb)
to_print = pformat(to_print, compact=True, indent=1)
to_print = f"\n TOSA deserialized: \n{to_print}"
elif spec.value == b"vela":
vela_cmd_stream = self.graph_module.lowered_module_0.processed_bytes
to_print = str(vela_cmd_stream)
to_print = f"\n Vela command stream: \n{to_print}"
break
assert to_print is not None, "No TOSA nor Vela compile spec found"
_dump_str(to_print, path_to_dump)
output += f"\nTOSA deserialized {node.name}: \n{to_print}\n"
elif output_format == "vela":
vela_cmd_stream = lowered_module.processed_bytes
output += (
f"\nVela command stream {node.name}: \n{vela_cmd_stream}\n"
)
else:
logger.warning(
f"No TOSA nor Vela compile spec found in compile specs of {node.name}."
)
continue

if not output:
logger.warning("No output to print generated from artifact.")
return

_dump_str(output, path_to_dump)


class Serialize(tester.Serialize):
Expand Down
10 changes: 5 additions & 5 deletions backends/arm/tosa_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,10 @@ def dbg_node(node):


# Output TOSA flatbuffer and test harness file
def dbg_tosa_dump(tosa_graph, path):
filename = "output.tosa"
def dbg_tosa_dump(tosa_graph: ts.TosaSerializer, path: str, suffix: str = ""):
filename = f"output{suffix}.tosa"

logger.info(f"Emitting debug output to {path}")
logger.info(f"Emitting debug output to: {path=}, {suffix=}")

os.makedirs(path, exist_ok=True)

Expand All @@ -63,7 +63,7 @@ def dbg_tosa_dump(tosa_graph, path):
f.write(fb)
assert os.path.exists(filepath_tosa_fb), "Failed to write TOSA flatbuffer"

filepath_desc_json = os.path.join(path, "desc.json")
filepath_desc_json = os.path.join(path, f"desc{suffix}.json")
with open(filepath_desc_json, "w") as f:
f.write(js)
assert os.path.exists(filepath_desc_json), "Failed to write TOSA JSON"
Expand All @@ -74,7 +74,7 @@ def dbg_fail(node, tosa_graph, path):
logger.warn("Internal error due to poorly handled node:")
dbg_node(node)
logger.warn(f"Debug output captured in '{path}'.")
raise RuntimeError("TOSA Internal Error on node, enable logging for further info")
raise RuntimeError("TOSA Internal Error on node, enable logging for further info.")


# Helper function to match TOSA's broadcasting rank requirement
Expand Down
Loading