Skip to content

Commit ad19cb8

Browse files
authored
Arm backend: Enable capture of debug information during TOSA serialization (#13908)
* Add _serialize_operator() function to NodeVisitor * Enable passing of DebugHook to NodeVisitor if enabled * Add option dump_debug_info to ArmCompileSpecBuilder
1 parent d7fd78b commit ad19cb8

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

55 files changed

+386
-90
lines changed

backends/arm/arm_backend.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
# backends. Converts via TOSA as an intermediate form supported by AoT and
1111
# JIT compiler flows.
1212
#
13+
from enum import Enum
1314
from typing import List, Optional
1415

1516
from executorch.backends.arm.tosa_specification import ( # type: ignore[import-not-found]
@@ -22,12 +23,16 @@
2223

2324

2425
class ArmCompileSpecBuilder:
26+
class DebugMode(Enum):
27+
JSON = 1
28+
2529
def __init__(self):
2630
self.compile_spec: List[CompileSpec] = []
2731
self.compiler_flags = []
2832
self.output_format = None
2933
self.path_for_intermediates = None
3034
self.tosa_spec = None
35+
self.tosa_debug_mode = None
3136

3237
def vgf_compile_spec(
3338
self,
@@ -163,6 +168,13 @@ def dump_intermediate_artifacts_to(
163168
self.path_for_intermediates = output_path
164169
return self
165170

171+
def dump_debug_info(self, debug_mode: DebugMode) -> "ArmCompileSpecBuilder":
172+
"""
173+
Dump debugging information into the intermediates path
174+
"""
175+
self.tosa_debug_mode = debug_mode.name
176+
return self
177+
166178
def build(self) -> List[CompileSpec]:
167179
"""
168180
Generate a list of compile spec objects from the builder
@@ -188,6 +200,16 @@ def build(self) -> List[CompileSpec]:
188200
CompileSpec("debug_artifact_path", self.path_for_intermediates.encode())
189201
)
190202

203+
if self.tosa_debug_mode is not None:
204+
if not self.path_for_intermediates:
205+
raise ValueError(
206+
"dump_debug_info() must be used in conjunction with dump_intermediate_artifacts_to()"
207+
)
208+
209+
self.compile_spec.append(
210+
CompileSpec("dump_debug_info", self.tosa_debug_mode.encode())
211+
)
212+
191213
return self.compile_spec
192214

193215

backends/arm/operators/node_visitor.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,11 @@
55

66
# pyre-unsafe
77

8-
from typing import Any, Dict, List
8+
from typing import Any, Dict, List, Optional
99

1010
import torch
1111

12+
from executorch.backends.arm.debug.schema import DebugHook
1213
from executorch.backends.arm.tosa_mapping import TosaArg
1314
from executorch.backends.arm.tosa_specification import TosaSpecification
1415
from torch.export import ExportedProgram
@@ -29,9 +30,38 @@ class NodeVisitor:
2930
TosaSpecification.create_from_string("TOSA-1.0+FP"),
3031
]
3132

32-
def __init__(self, exported_program: ExportedProgram, tosa_spec: TosaSpecification):
33+
def __init__(
34+
self,
35+
exported_program: ExportedProgram,
36+
tosa_spec: TosaSpecification,
37+
debug_hook: Optional[DebugHook] = None,
38+
):
3339
self._exported_program = exported_program
3440
self.tosa_spec = tosa_spec
41+
self.debug_hook = debug_hook
42+
43+
def _serialize_operator(
44+
self,
45+
node: torch.fx.Node,
46+
tosa_graph: Any,
47+
tosa_op: Any,
48+
inputs: List[str],
49+
outputs: List[str],
50+
attributes: Optional[Any] = None,
51+
) -> None:
52+
tosa_graph.addOperator(
53+
tosa_op,
54+
inputs=inputs,
55+
outputs=outputs,
56+
attributes=attributes,
57+
)
58+
59+
if self.debug_hook:
60+
self.debug_hook.add(
61+
node,
62+
tosa_op=outputs[0],
63+
tosa_op_id=tosa_op,
64+
)
3565

3666
def define_node(
3767
self,

backends/arm/operators/op_abs.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,9 @@ def define_node(
123123
)
124124

125125
# MI lowering
126-
tosa_graph.addOperator(
126+
self._serialize_operator(
127+
node,
128+
tosa_graph,
127129
ts.TosaOp.Op().ABS,
128130
[inputs[0].name],
129131
[output.name],

backends/arm/operators/op_add.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,9 @@ def define_node(
7373
input1, input2 = rescaled_inputs
7474

7575
# Do the INT32 Add
76-
tosa_graph.addOperator(
76+
self._serialize_operator(
77+
node,
78+
tosa_graph,
7779
ts.TosaOp.Op().ADD,
7880
[input1.name, input2.name],
7981
[add_output.name],
@@ -127,7 +129,9 @@ def define_node(
127129
input1, input2 = inputs
128130

129131
# FP lowering
130-
tosa_graph.addOperator(
132+
self._serialize_operator(
133+
node,
134+
tosa_graph,
131135
ts.TosaOp.Op().ADD,
132136
[input1.name, input2.name],
133137
[output.name],

backends/arm/operators/op_amax.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,11 @@ def define_node(
6161

6262
attr = ts.TosaSerializerAttribute()
6363
attr.ReduceMaxAttribute(axis=input.dim_order.index(dim), nan_mode=1)
64-
tosa_graph.addOperator(
65-
ts.TosaOp.Op().REDUCE_MAX, [input.name], [output.name], attr
64+
self._serialize_operator(
65+
node,
66+
tosa_graph,
67+
ts.TosaOp.Op().REDUCE_MAX,
68+
[input.name],
69+
[output.name],
70+
attr,
6671
)

backends/arm/operators/op_amin.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,11 @@ def define_node(
6161

6262
attr = ts.TosaSerializerAttribute()
6363
attr.ReduceMinAttribute(axis=input.dim_order.index(dim), nan_mode=1)
64-
tosa_graph.addOperator(
65-
ts.TosaOp.Op().REDUCE_MIN, [input.name], [output.name], attr
64+
self._serialize_operator(
65+
node,
66+
tosa_graph,
67+
ts.TosaOp.Op().REDUCE_MIN,
68+
[input.name],
69+
[output.name],
70+
attr,
6671
)

backends/arm/operators/op_any.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,11 @@ def define_node(
5252
attr = ts.TosaSerializerAttribute()
5353
attr.ReduceAnyAttribute(inputs[0].dim_order.index(dim))
5454

55-
tosa_graph.addOperator(
56-
ts.TosaOp.Op().REDUCE_ANY, [inputs[0].name], [output.name], attr
55+
self._serialize_operator(
56+
node,
57+
tosa_graph,
58+
ts.TosaOp.Op().REDUCE_ANY,
59+
[inputs[0].name],
60+
[output.name],
61+
attr,
5762
)

backends/arm/operators/op_avg_pool2d.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,9 @@ def _build_generic_avgpool2d(
100100
shape=[1], dtype=output.dtype, vals=[output_zp]
101101
)
102102

103-
tosa_graph.addOperator(
103+
self._serialize_operator(
104+
node,
105+
tosa_graph,
104106
ts.TosaOp.Op().AVG_POOL2D,
105107
[input_tensor.name, input_zp_tensor.name, output_zp_tensor.name],
106108
[output.name],

backends/arm/operators/op_bmm.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,9 @@ def define_node(
7878
tosa_graph.addConst([1], inputs[1].dtype, [input1_zp], name=f"{node.name}_B_ZP")
7979

8080
# Add the MATMUL to the TOSA graph.
81-
tosa_graph.addOperator(
81+
self._serialize_operator(
82+
node,
83+
tosa_graph,
8284
ts.TosaOp.Op().MATMUL,
8385
[
8486
inputs[0].name,

backends/arm/operators/op_cat.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,9 @@ def define_node(
4747
attr = ts.TosaSerializerAttribute()
4848
attr.ConcatAttribute(dim)
4949

50-
tosa_graph.addOperator(
50+
self._serialize_operator(
51+
node,
52+
tosa_graph,
5153
ts.TosaOp.Op().CONCAT,
5254
[tensor.name for tensor in tensors],
5355
[output.name],

0 commit comments

Comments
 (0)