Skip to content

Commit 3899beb

Browse files
tom-armmansnils
andauthored
Arm backend: Propagate debug info to TOSA flatbuffer (#13998)
* Add TOSA option to DebugMode enum Co-authored-by: Måns Nilsson <[email protected]>
1 parent 9dfc0d6 commit 3899beb

File tree

6 files changed

+94
-31
lines changed

6 files changed

+94
-31
lines changed

backends/arm/arm_backend.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
class ArmCompileSpecBuilder:
2424
class DebugMode(Enum):
2525
JSON = 1
26+
TOSA = 2
2627

2728
def __init__(self):
2829
self.compile_spec: List[CompileSpec] = []

backends/arm/debug/schema.py

Lines changed: 34 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,13 @@
88
import json
99

1010
from dataclasses import asdict, dataclass
11-
from typing import Any
11+
from typing import Any, Optional
1212

1313
import serializer.tosa_serializer as ts # type: ignore
1414
import torch
1515

16+
from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder
17+
1618
from torch.fx.traceback import NodeSource
1719

1820

@@ -97,37 +99,52 @@ def from_node(node: torch.fx.Node) -> TorchDebugSchema:
9799
class DebugSchema:
98100
event_id: int
99101
aten_info: ATenDebugSchema
100-
tosa_info: TosaDebugSchema
102+
tosa_info: Optional[TosaDebugSchema]
101103
torch_info: TorchDebugSchema
102104

105+
def to_dict(self) -> dict[str, Any]:
106+
output = asdict(self)
107+
108+
if self.tosa_info is None:
109+
output.pop("tosa_info")
110+
111+
return output
112+
103113

104114
class DebugHook:
105-
def __init__(self) -> None:
115+
def __init__(self, debug_mode: ArmCompileSpecBuilder.DebugMode) -> None:
106116
self._debug_events: list[DebugSchema] = []
107117
self.__op_id_to_name = {}
118+
self.mode = debug_mode
108119

109120
# Build up a mapping from TOSA 1.0 operator IDs to their names
110121
for name, val in vars(ts.Op).items():
111122
self.__op_id_to_name[val] = name
112123

113-
def add(self, node: torch.fx.Node, tosa_op: Any, tosa_op_id: int) -> None:
114-
tosa_debug_info = TosaDebugSchema(
115-
node_name=str(tosa_op),
116-
operator_name=self.__op_id_to_name[tosa_op_id],
117-
operator_id=tosa_op_id,
118-
)
124+
def add(self, node: torch.fx.Node, tosa_op: Any, tosa_op_id: int) -> DebugSchema:
125+
tosa_debug_info = None
126+
127+
# If the debug data is being embedded into the TOSA flatbuffer
128+
# do not collect TOSADebugSchema data, it's redundent
129+
if self.mode != ArmCompileSpecBuilder.DebugMode.TOSA:
130+
tosa_debug_info = TosaDebugSchema(
131+
node_name=str(tosa_op),
132+
operator_name=self.__op_id_to_name[tosa_op_id],
133+
operator_id=tosa_op_id,
134+
)
119135

120136
aten_debug_info = ATenDebugSchema.from_node(node)
121137
torch_debug_info = TorchDebugSchema.from_node(node)
122138

123-
self._debug_events.append(
124-
DebugSchema(
125-
event_id=len(self._debug_events),
126-
aten_info=aten_debug_info,
127-
tosa_info=tosa_debug_info,
128-
torch_info=torch_debug_info,
129-
)
139+
debug_info = DebugSchema(
140+
event_id=len(self._debug_events),
141+
aten_info=aten_debug_info,
142+
tosa_info=tosa_debug_info,
143+
torch_info=torch_debug_info,
130144
)
145+
self._debug_events.append(debug_info)
146+
147+
return debug_info
131148

132149
def serialize(self) -> str:
133-
return json.dumps([asdict(event) for event in self._debug_events], indent=4)
150+
return json.dumps([event.to_dict() for event in self._debug_events], indent=4)

backends/arm/operators/node_visitor.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,12 @@
55

66
# pyre-unsafe
77

8+
import json
89
from typing import Any, Dict, List, Optional
910

1011
import torch
1112

13+
from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder
1214
from executorch.backends.arm.debug.schema import DebugHook
1315
from executorch.backends.arm.tosa.mapping import TosaArg
1416
from executorch.backends.arm.tosa.specification import TosaSpecification
@@ -49,20 +51,25 @@ def _serialize_operator(
4951
outputs: List[str],
5052
attributes: Optional[Any] = None,
5153
) -> None:
54+
op_location = ""
55+
if self.debug_hook:
56+
debug_info = self.debug_hook.add(
57+
node,
58+
tosa_op=outputs[0],
59+
tosa_op_id=tosa_op,
60+
)
61+
62+
if self.debug_hook.mode == ArmCompileSpecBuilder.DebugMode.TOSA:
63+
op_location = json.dumps(debug_info.to_dict())
64+
5265
tosa_graph.addOperator(
5366
tosa_op,
5467
inputs=inputs,
5568
outputs=outputs,
5669
attributes=attributes,
70+
location=op_location,
5771
)
5872

59-
if self.debug_hook:
60-
self.debug_hook.add(
61-
node,
62-
tosa_op=outputs[0],
63-
tosa_op_id=tosa_op,
64-
)
65-
6673
def define_node(
6774
self,
6875
node: torch.fx.Node,

backends/arm/test/misc/test_debug_feats.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,27 @@ def test_dump_tosa_debug_json(test_data: input_t1):
222222
pytest.fail("Failed to load debug JSON file")
223223

224224

225+
@common.parametrize("test_data", Linear.inputs)
226+
def test_dump_tosa_debug_tosa(test_data: input_t1):
227+
with tempfile.TemporaryDirectory() as tmpdir:
228+
pipeline = TosaPipelineINT[input_t1](
229+
module=Linear(),
230+
test_data=test_data,
231+
aten_op=[],
232+
exir_op=[],
233+
custom_path=tmpdir,
234+
tosa_debug_mode=ArmCompileSpecBuilder.DebugMode.TOSA,
235+
)
236+
237+
pipeline.pop_stage("run_method_and_compare_outputs")
238+
pipeline.run()
239+
240+
json_output_path = Path(tmpdir) / "debug.json"
241+
242+
# A JSON file should not be created when TOSA mode used
243+
assert not json_output_path.exists()
244+
245+
225246
@common.parametrize("test_data", Linear.inputs)
226247
def test_dump_tosa_ops(caplog, test_data: input_t1):
227248
pipeline = TosaPipelineINT[input_t1](Linear(), test_data, [], [])

backends/arm/test/misc/test_debug_hook.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from dataclasses import dataclass
77
from types import SimpleNamespace
88

9+
from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder
910
from executorch.backends.arm.debug.schema import DebugHook, DebugSchema
1011
from executorch.backends.arm.test import common
1112

@@ -156,8 +157,8 @@ def _compare_node_and_schema(debug_event: DebugSchema, mocked_node):
156157

157158

158159
@common.parametrize("test_data", TESTCASES)
159-
def test_debug_hook_add_1(test_data: DebugHookTestCase):
160-
hook = DebugHook()
160+
def test_debug_hook_add_json(test_data: DebugHookTestCase):
161+
hook = DebugHook(ArmCompileSpecBuilder.DebugMode.JSON)
161162
hook.add(test_data.mock_node, test_data.tosa_op, test_data.op_id)
162163

163164
debug_events = hook._debug_events
@@ -166,3 +167,17 @@ def test_debug_hook_add_1(test_data: DebugHookTestCase):
166167

167168
_compare_tosa_and_schema(debug_events[0], test_data.tosa_op)
168169
_compare_node_and_schema(debug_events[0], test_data.mock_node)
170+
171+
172+
@common.parametrize("test_data", TESTCASES)
173+
def test_debug_hook_add_tosa(test_data: DebugHookTestCase):
174+
hook = DebugHook(ArmCompileSpecBuilder.DebugMode.TOSA)
175+
hook.add(test_data.mock_node, test_data.tosa_op, test_data.op_id)
176+
177+
debug_events = hook._debug_events
178+
assert len(debug_events) == test_data.expected_events
179+
assert len(debug_events[0].torch_info.node_trace) == test_data.num_nodes_traced
180+
181+
assert debug_events[0].tosa_info is None
182+
183+
_compare_node_and_schema(debug_events[0], test_data.mock_node)

backends/arm/tosa/backend.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from typing import cast, final, List
1515

1616
import serializer.tosa_serializer as ts # type: ignore
17+
from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder
1718
from executorch.backends.arm.common.debug import debug_fail, debug_tosa_dump
1819
from executorch.backends.arm.debug.schema import DebugHook
1920
from executorch.backends.arm.process_node import (
@@ -100,7 +101,7 @@ def preprocess( # noqa: C901
100101

101102
debug_hook = None
102103
if dump_debug_info is not None:
103-
debug_hook = DebugHook()
104+
debug_hook = DebugHook(ArmCompileSpecBuilder.DebugMode[dump_debug_info])
104105

105106
# TODO: Fix the need to lazily import this.
106107
from executorch.backends.arm.operators.node_visitor import get_node_visitors
@@ -136,10 +137,11 @@ def preprocess( # noqa: C901
136137
suffix="{}".format(f"_{tag}" if tag else "") + (f"_{tosa_spec}"),
137138
)
138139

139-
if debug_hook:
140-
json_output = debug_hook.serialize()
141-
with open(f"{artifact_path}/debug.json", "w") as f:
142-
f.write(json_output)
140+
if debug_hook is not None:
141+
if debug_hook.mode == ArmCompileSpecBuilder.DebugMode.JSON:
142+
json_output = debug_hook.serialize()
143+
with open(f"{artifact_path}/debug.json", "w") as f:
144+
f.write(json_output)
143145

144146
# Serialize and return the TOSA flatbuffer.
145147
binary = bytes(tosa_graph.serialize())

0 commit comments

Comments
 (0)