diff --git a/backends/arm/_passes/arm_pass_utils.py b/backends/arm/_passes/arm_pass_utils.py index 00eb395be9f..1e0c21239e2 100644 --- a/backends/arm/_passes/arm_pass_utils.py +++ b/backends/arm/_passes/arm_pass_utils.py @@ -13,7 +13,7 @@ import torch import torch.fx -from executorch.backends.arm.common.debug import get_node_debug_info +from executorch.backends.arm.tosa_utils import get_node_debug_info from executorch.exir import ExportedProgram from executorch.exir.dialects._ops import ops as exir_ops diff --git a/backends/arm/common/__init__.py b/backends/arm/common/__init__.py deleted file mode 100644 index c8d1c683da3..00000000000 --- a/backends/arm/common/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -# Copyright 2025 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. diff --git a/backends/arm/common/debug.py b/backends/arm/common/debug.py deleted file mode 100644 index bca6c06d140..00000000000 --- a/backends/arm/common/debug.py +++ /dev/null @@ -1,87 +0,0 @@ -# Copyright 2025 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import logging -import os -from typing import Optional - -import serializer.tosa_serializer as ts # type: ignore -import torch -from executorch.exir.print_program import inspect_node - -logger = logging.getLogger(__name__) - - -def debug_node(node: torch.fx.Node, graph_module: torch.fx.GraphModule): - # Debug output of node information - logger.info(get_node_debug_info(node, graph_module)) - - -def get_node_debug_info( - node: torch.fx.Node, graph_module: torch.fx.GraphModule | None = None -) -> str: - output = ( - f" {inspect_node(graph=graph_module.graph, node=node)}\n" - if graph_module - else "" - "-- NODE DEBUG INFO --\n" - f" Op is {node.op}\n" - f" Name is {node.name}\n" - f" Node target is {node.target}\n" - f" Node args is {node.args}\n" - f" Node kwargs is {node.kwargs}\n" - f" Node users is {node.users}\n" - " Node.meta = \n" - ) - for k, v in node.meta.items(): - if k == "stack_trace": - matches = v.split("\n") - output += " 'stack_trace =\n" - for m in matches: - output += f" {m}\n" - else: - output += f" '{k}' = {v}\n" - - if isinstance(v, list): - for i in v: - output += f" {i}\n" - return output - - -# Output TOSA flatbuffer and test harness file -def debug_tosa_dump(tosa_graph: ts.TosaSerializer, path: str, suffix: str = ""): - filename = f"output{suffix}.tosa" - - logger.info(f"Emitting debug output to: {path=}, {suffix=}") - - os.makedirs(path, exist_ok=True) - - fb = tosa_graph.serialize() - js = tosa_graph.writeJson(filename) - - filepath_tosa_fb = os.path.join(path, filename) - with open(filepath_tosa_fb, "wb") as f: - f.write(fb) - if not os.path.exists(filepath_tosa_fb): - raise IOError("Failed to write TOSA flatbuffer") - - filepath_desc_json = os.path.join(path, f"desc{suffix}.json") - with open(filepath_desc_json, "w") as f: - f.write(js) - if not os.path.exists(filepath_desc_json): - raise IOError("Failed to write TOSA JSON") - - -def debug_fail( - node, - graph_module, - tosa_graph: Optional[ts.TosaSerializer] = None, - path: Optional[str] = None, -): - logger.warning("Internal error due to poorly handled node:") - if tosa_graph is not None and path is not None: - debug_tosa_dump(tosa_graph, path) - logger.warning(f"Debug output captured in '{path}'.") - debug_node(node, graph_module) diff --git a/backends/arm/quantizer/quantization_annotator.py b/backends/arm/quantizer/quantization_annotator.py index a6f5671a881..bbb5cdc373a 100644 --- a/backends/arm/quantizer/quantization_annotator.py +++ b/backends/arm/quantizer/quantization_annotator.py @@ -11,8 +11,8 @@ import torch import torch.fx import torch.nn.functional as F -from executorch.backends.arm.common.debug import get_node_debug_info from executorch.backends.arm.quantizer import QuantizationConfig +from executorch.backends.arm.tosa_utils import get_node_debug_info from torch._subclasses import FakeTensor from torch.fx import Node diff --git a/backends/arm/tosa_backend.py b/backends/arm/tosa_backend.py index 7062d68b944..1211261c23b 100644 --- a/backends/arm/tosa_backend.py +++ b/backends/arm/tosa_backend.py @@ -19,12 +19,12 @@ from executorch.backends.arm._passes import ( ArmPassManager, ) # usort: skip -from executorch.backends.arm.common.debug import debug_fail, debug_tosa_dump from executorch.backends.arm.process_node import ( process_call_function, process_output, process_placeholder, ) +from executorch.backends.arm.tosa_utils import dbg_fail, dbg_tosa_dump from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult from executorch.exir.backend.compile_spec_schema import CompileSpec from torch.export.exported_program import ExportedProgram @@ -115,12 +115,12 @@ def preprocess( # noqa: C901 # any checking of compatibility. raise RuntimeError(f"{node.name} is unsupported op {node.op}") except Exception: - debug_fail(node, graph_module, tosa_graph, artifact_path) + dbg_fail(node, graph_module, tosa_graph, artifact_path) raise if artifact_path: tag = arm_get_first_delegation_tag(graph_module) - debug_tosa_dump( + dbg_tosa_dump( tosa_graph, artifact_path, suffix="{}".format(f"_{tag}" if tag else "") + (f"_{tosa_spec}"), diff --git a/backends/arm/tosa_utils.py b/backends/arm/tosa_utils.py index 7d544e46bfc..bc495b12294 100644 --- a/backends/arm/tosa_utils.py +++ b/backends/arm/tosa_utils.py @@ -6,7 +6,8 @@ # pyre-unsafe import logging -from typing import Any +import os +from typing import Any, Optional import numpy as np import serializer.tosa_serializer as ts # type: ignore @@ -19,6 +20,7 @@ from executorch.backends.arm.tosa_specification import TosaSpecification from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.print_program import inspect_node from torch._subclasses.fake_tensor import FakeTensor from torch.fx import Node @@ -26,6 +28,77 @@ logger = logging.getLogger(__name__) +def dbg_node(node: torch.fx.Node, graph_module: torch.fx.GraphModule): + # Debug output of node information + logger.info(get_node_debug_info(node, graph_module)) + + +def get_node_debug_info( + node: torch.fx.Node, graph_module: torch.fx.GraphModule | None = None +) -> str: + output = ( + f" {inspect_node(graph=graph_module.graph, node=node)}\n" + if graph_module + else "" + "-- NODE DEBUG INFO --\n" + f" Op is {node.op}\n" + f" Name is {node.name}\n" + f" Node target is {node.target}\n" + f" Node args is {node.args}\n" + f" Node kwargs is {node.kwargs}\n" + f" Node users is {node.users}\n" + " Node.meta = \n" + ) + for k, v in node.meta.items(): + if k == "stack_trace": + matches = v.split("\n") + output += " 'stack_trace =\n" + for m in matches: + output += f" {m}\n" + else: + output += f" '{k}' = {v}\n" + + if isinstance(v, list): + for i in v: + output += f" {i}\n" + return output + + +# Output TOSA flatbuffer and test harness file +def dbg_tosa_dump(tosa_graph: ts.TosaSerializer, path: str, suffix: str = ""): + filename = f"output{suffix}.tosa" + + logger.info(f"Emitting debug output to: {path=}, {suffix=}") + + os.makedirs(path, exist_ok=True) + + fb = tosa_graph.serialize() + js = tosa_graph.writeJson(filename) + + filepath_tosa_fb = os.path.join(path, filename) + with open(filepath_tosa_fb, "wb") as f: + f.write(fb) + assert os.path.exists(filepath_tosa_fb), "Failed to write TOSA flatbuffer" + + filepath_desc_json = os.path.join(path, f"desc{suffix}.json") + with open(filepath_desc_json, "w") as f: + f.write(js) + assert os.path.exists(filepath_desc_json), "Failed to write TOSA JSON" + + +def dbg_fail( + node, + graph_module, + tosa_graph: Optional[ts.TosaSerializer] = None, + path: Optional[str] = None, +): + logger.warning("Internal error due to poorly handled node:") + if tosa_graph is not None and path is not None: + dbg_tosa_dump(tosa_graph, path) + logger.warning(f"Debug output captured in '{path}'.") + dbg_node(node, graph_module) + + def getNodeArgs(node: Node, tosa_spec: TosaSpecification) -> list[TosaArg]: try: return [TosaArg(arg, tosa_spec) for arg in node.args]