Skip to content

Commit 0cca3ca

Browse files
authored
Arm backend: Add schema for debug data capture (#13868)
* Add DebugSchema defining debugging data format * Defines DebugHook used to capture TOSA serialization events * Add tests for DebugHook
1 parent 0b0e2dc commit 0cca3ca

File tree

3 files changed

+305
-0
lines changed

3 files changed

+305
-0
lines changed

backends/arm/debug/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.

backends/arm/debug/schema.py

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from __future__ import annotations
7+
8+
import json
9+
10+
from dataclasses import asdict, dataclass
11+
from typing import Any
12+
13+
import serializer.tosa_serializer as ts # type: ignore
14+
import torch
15+
16+
from torch.fx.traceback import NodeSource
17+
18+
19+
@dataclass
class TosaDebugSchema:
    """TOSA-side half of a debug event.

    Holds the identity of the TOSA operator that was recorded for an event:
    a name for the emitted node plus the operator's name and numeric ID.
    """

    # Name identifying the emitted TOSA node.
    node_name: str
    # Symbolic name of the TOSA operator.
    operator_name: str
    # Numeric TOSA operator ID matching ``operator_name``.
    operator_id: int
24+
25+
26+
@dataclass
class ATenDebugSchema:
    """ATen-side half of a debug event: the FX node name and its operator."""

    # Name of the FX node.
    node_name: str
    # Name of the operator the FX node targets.
    operator_name: str

    @staticmethod
    def from_node(node: torch.fx.Node) -> ATenDebugSchema:
        """Build the schema entry from an FX node.

        Args:
            node: FX node whose ``name`` and ``target`` are recorded.

        Returns:
            An ``ATenDebugSchema`` whose ``operator_name`` is the target's
            ``__name__`` for callable targets and the target itself otherwise.
        """
        # node.target is Union[Callable[..., Any], str], so the name has to be
        # extracted differently depending on the type.
        target = node.target
        if callable(target):
            # Not every callable carries ``__name__`` (``functools.partial``
            # objects and callable instances do not); fall back to the string
            # form instead of raising AttributeError while collecting debug data.
            name = getattr(target, "__name__", None)
            operator_name = name if name is not None else str(target)
        else:
            operator_name = target

        return ATenDebugSchema(node_name=node.name, operator_name=operator_name)
40+
41+
42+
@dataclass
class TorchDebugSchema:
    """Torch-side half of a debug event, read from an FX node's ``meta``.

    Every field falls back to a human-readable placeholder string when the
    corresponding metadata key is missing.
    """

    # Stack trace split into individual lines.
    stack_trace: list[str]
    # Flattened ``from_node`` history, or a placeholder string.
    node_trace: list[dict[str, Any]] | str
    # ``nn_module_stack`` metadata, or a placeholder string.
    nn_module_stack: dict[str, Any] | str
    # ``torch_fn`` metadata, or a placeholder string.
    torch_fn: tuple[str, str] | str

    @staticmethod
    def serialize_node_trace(node_trace: list[NodeSource]) -> list[dict[str, Any]]:
        """Flatten the nested ``from_node`` history into a list of dicts.

        Each entry is converted to a dict and every entry reachable through
        its ``from_node`` list is visited as well. Entries are taken from the
        end of a stack, so later siblings are emitted before earlier ones.

        Args:
            node_trace: Top-level trace entries of a node.

        Returns:
            One dict per visited entry. ``parent_graph_id`` is the
            ``graph_id`` of the entry it was reached from, or ``-1`` for the
            top-level entries.
        """
        flattened: list[dict[str, Any]] = []
        # -1 marks top-level entries, which have no parent.
        pending = [(source, -1) for source in node_trace]

        while pending:
            source, parent_id = pending.pop()
            flattened.append(
                {
                    "name": source.name,
                    "target": source.target,
                    "graph_id": source.graph_id,
                    "pass_name": source.pass_name,
                    "action": source._get_action_string(),
                    "parent_graph_id": parent_id,
                }
            )
            pending.extend((child, source.graph_id) for child in source.from_node)

        return flattened

    @staticmethod
    def from_node(node: torch.fx.Node) -> TorchDebugSchema:
        """Build the schema entry from the metadata of an FX node.

        Args:
            node: FX node whose ``meta`` dict is read.

        Returns:
            A ``TorchDebugSchema`` populated from ``stack_trace``,
            ``from_node``, ``nn_module_stack`` and ``torch_fn``.
        """
        meta = node.meta

        node_trace: str | list[dict[str, Any]] = "No node trace available."
        if "from_node" in meta:
            # Flatten the node trace so the result contains no nesting.
            node_trace = TorchDebugSchema.serialize_node_trace(meta["from_node"])

        # A key that is present but holds None would bypass the ``get`` default
        # and crash on ``.split``; treat it the same as a missing key.
        stack_trace = meta.get("stack_trace")
        if stack_trace is None:
            stack_trace = "No stack trace available"

        return TorchDebugSchema(
            stack_trace=stack_trace.split("\n"),
            node_trace=node_trace,
            nn_module_stack=meta.get(
                "nn_module_stack", "No module stack trace available"
            ),
            torch_fn=meta.get("torch_fn", "No torch_fn available"),
        )
94+
95+
96+
@dataclass
class DebugSchema:
    """One captured debug event, tying together its ATen, TOSA and Torch views."""

    # Identifier of the event.
    event_id: int
    # Information about the originating FX node and its operator.
    aten_info: ATenDebugSchema
    # Information about the TOSA operator recorded for the event.
    tosa_info: TosaDebugSchema
    # Provenance read from the FX node's metadata.
    torch_info: TorchDebugSchema
102+
103+
104+
class DebugHook:
    """Collects one ``DebugSchema`` event per recorded TOSA operator.

    Events are appended through ``add`` and exported as JSON through
    ``serialize``.
    """

    def __init__(self) -> None:
        """Create an empty hook and build the operator ID -> name table."""
        self._debug_events: list[DebugSchema] = []

        # Build up a mapping from TOSA 1.0 operator IDs to their names.
        # ``vars`` on a class also exposes non-operator attributes such as
        # ``__module__`` and ``__doc__``; keep only public integer constants so
        # the table contains nothing but real operators.
        self.__op_id_to_name: dict[int, str] = {
            val: name
            for name, val in vars(ts.Op).items()
            if not name.startswith("_") and isinstance(val, int)
        }

    def add(self, node: torch.fx.Node, tosa_op: Any, tosa_op_id: int) -> None:
        """Record a debug event for ``node``.

        Args:
            node: FX node the event belongs to.
            tosa_op: TOSA-side object for the event; its ``str()`` form is
                stored as the TOSA node name.
            tosa_op_id: Numeric TOSA operator ID.

        Raises:
            KeyError: If ``tosa_op_id`` is not a known operator ID.
        """
        tosa_debug_info = TosaDebugSchema(
            node_name=str(tosa_op),
            operator_name=self.__op_id_to_name[tosa_op_id],
            operator_id=tosa_op_id,
        )

        aten_debug_info = ATenDebugSchema.from_node(node)
        torch_debug_info = TorchDebugSchema.from_node(node)

        self._debug_events.append(
            DebugSchema(
                # Event IDs are the insertion index of the event.
                event_id=len(self._debug_events),
                aten_info=aten_debug_info,
                tosa_info=tosa_debug_info,
                torch_info=torch_debug_info,
            )
        )

    def serialize(self) -> str:
        """Return all recorded events as a JSON array indented by 4 spaces."""
        return json.dumps([asdict(event) for event in self._debug_events], indent=4)
Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from dataclasses import dataclass
7+
from types import SimpleNamespace
8+
9+
from executorch.backends.arm.debug.schema import DebugHook, DebugSchema
10+
from executorch.backends.arm.test import common
11+
12+
13+
@dataclass
class DebugHookTestCase:
    """Inputs and expected results for one ``DebugHook.add`` scenario."""

    # Stand-in for the FX node passed to the hook.
    mock_node: SimpleNamespace
    # TOSA operator value passed to the hook.
    tosa_op: str
    # TOSA operator ID passed to the hook.
    op_id: int
    # Number of events the hook should hold afterwards.
    expected_events: int
    # Expected length of the recorded node trace.
    num_nodes_traced: int
20+
21+
22+
def create_mock_node_1():
    """Return a mock FX node with full metadata and a two-level trace.

    The node's ``from_node`` metadata holds one entry which itself has one
    ancestor, so two trace entries are reachable in total.
    """
    # Innermost trace entry: it has no ancestors of its own.
    inner_source = SimpleNamespace(
        name="convolution",
        target="aten.convolution.default",
        graph_id=6052414368,
        pass_name="ExportedProgram.module()",
        action="create",
        from_node=[],
        _get_action_string=lambda: "create",
    )

    # Outer trace entry pointing back at the innermost one.
    outer_source = SimpleNamespace(
        name="convolution",
        target="aten.convolution.default",
        graph_id=5705954832,
        pass_name="Interpreter_PropagateUnbackedSymInts",
        action="create",
        from_node=[inner_source],
        _get_action_string=lambda: "create",
    )

    metadata = {
        "stack_trace": 'File "models/model.py", line 221, in forward\nreturn self.features(x)',
        "nn_module_stack": {"__self__": ["", "model.Model"]},
        "torch_fn": ("conv2d", "builtin_function_or_method.conv2d"),
        "from_node": [outer_source],
    }

    return SimpleNamespace(
        name="aten_convolution_default",
        target="aten.convolution.default",
        meta=metadata,
    )
58+
59+
60+
def create_mock_node_2():
    """Return a mock FX node whose only metadata is a single-entry trace.

    No ``stack_trace``, ``nn_module_stack`` or ``torch_fn`` keys are present.
    """
    # Single trace entry without ancestors.
    only_source = SimpleNamespace(
        name="convolution",
        target="aten.convolution.default",
        graph_id=5705954832,
        pass_name="Interpreter_PropagateUnbackedSymInts",
        action="create",
        from_node=[],
        _get_action_string=lambda: "create",
    )

    metadata = {"from_node": [only_source]}

    return SimpleNamespace(
        name="aten_convolution_default",
        target="aten.convolution.default",
        meta=metadata,
    )
83+
84+
85+
def create_mock_node_3():
    """Return a mock FX node whose metadata holds only an empty trace."""
    metadata = {"from_node": []}

    return SimpleNamespace(
        name="aten_convolution_default",
        target="aten.convolution.default",
        meta=metadata,
    )
95+
96+
97+
def _compare_tosa_and_schema(debug_event: DebugSchema, tosa_op):
98+
tosa_info = debug_event.tosa_info
99+
100+
assert tosa_info.node_name == tosa_op
101+
102+
# The mapping between op_ids to operator names could change
103+
# So just check operator_name is a string
104+
assert isinstance(tosa_info.operator_name, str)
105+
106+
107+
def _compare_node_and_schema(debug_event: DebugSchema, mocked_node):
108+
# Check aten info
109+
aten_info = debug_event.aten_info
110+
111+
assert aten_info.node_name == mocked_node.name
112+
assert aten_info.operator_name == mocked_node.target
113+
114+
# Check torch info
115+
torch_info = debug_event.torch_info
116+
117+
if "nn_module_stack" in mocked_node.meta:
118+
assert torch_info.nn_module_stack == mocked_node.meta["nn_module_stack"]
119+
else:
120+
assert torch_info.nn_module_stack == "No module stack trace available"
121+
122+
if "stack_trace" in mocked_node.meta:
123+
assert torch_info.stack_trace == mocked_node.meta["stack_trace"].split("\n")
124+
else:
125+
assert torch_info.stack_trace == ["No stack trace available"]
126+
127+
if "torch_fn" in mocked_node.meta:
128+
assert torch_info.torch_fn == mocked_node.meta["torch_fn"]
129+
else:
130+
assert torch_info.torch_fn == "No torch_fn available"
131+
132+
133+
# Scenarios for the DebugHook test, keyed by test id.
TESTCASES = {
    # Full metadata with a nested trace: two trace entries are expected.
    "mocked_node": DebugHookTestCase(
        tosa_op="layer-1",
        op_id=3,
        mock_node=create_mock_node_1(),
        num_nodes_traced=2,
        expected_events=1,
    ),
    # Only a single-entry trace in the metadata.
    "mocked_node_partially_empty": DebugHookTestCase(
        tosa_op="layer-1",
        op_id=1,
        mock_node=create_mock_node_2(),
        num_nodes_traced=1,
        expected_events=1,
    ),
    # Empty trace and no other metadata.
    "mocked_node_all_empty": DebugHookTestCase(
        tosa_op="layer-2",
        op_id=1,
        mock_node=create_mock_node_3(),
        num_nodes_traced=0,
        expected_events=1,
    ),
}
156+
157+
158+
@common.parametrize("test_data", TESTCASES)
def test_debug_hook_add_1(test_data: DebugHookTestCase):
    """Add one node to a fresh hook and check the recorded event against it."""
    hook = DebugHook()
    hook.add(test_data.mock_node, test_data.tosa_op, test_data.op_id)

    recorded = hook._debug_events
    assert len(recorded) == test_data.expected_events

    # Only one node was added, so the first event is the one to inspect.
    event = recorded[0]
    assert len(event.torch_info.node_trace) == test_data.num_nodes_traced

    _compare_tosa_and_schema(event, test_data.tosa_op)
    _compare_node_and_schema(event, test_data.mock_node)

0 commit comments

Comments
 (0)