|
8 | 8 | import json
|
9 | 9 |
|
10 | 10 | from dataclasses import asdict, dataclass
|
11 |
| -from typing import Any |
| 11 | +from typing import Any, Optional |
12 | 12 |
|
13 | 13 | import serializer.tosa_serializer as ts # type: ignore
|
14 | 14 | import torch
|
15 | 15 |
|
| 16 | +from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder |
| 17 | + |
16 | 18 | from torch.fx.traceback import NodeSource
|
17 | 19 |
|
18 | 20 |
|
@@ -97,37 +99,52 @@ def from_node(node: torch.fx.Node) -> TorchDebugSchema:
|
97 | 99 | class DebugSchema:
|
98 | 100 | event_id: int
|
99 | 101 | aten_info: ATenDebugSchema
|
100 |
| - tosa_info: TosaDebugSchema |
| 102 | + tosa_info: Optional[TosaDebugSchema] |
101 | 103 | torch_info: TorchDebugSchema
|
102 | 104 |
|
| 105 | + def to_dict(self) -> dict[str, Any]: |
| 106 | + output = asdict(self) |
| 107 | + |
| 108 | + if self.tosa_info is None: |
| 109 | + output.pop("tosa_info") |
| 110 | + |
| 111 | + return output |
| 112 | + |
103 | 113 |
|
104 | 114 | class DebugHook:
|
105 |
| - def __init__(self) -> None: |
| 115 | + def __init__(self, debug_mode: ArmCompileSpecBuilder.DebugMode) -> None: |
106 | 116 | self._debug_events: list[DebugSchema] = []
|
107 | 117 | self.__op_id_to_name = {}
|
| 118 | + self.mode = debug_mode |
108 | 119 |
|
109 | 120 | # Build up a mapping from TOSA 1.0 operator IDs to their names
|
110 | 121 | for name, val in vars(ts.Op).items():
|
111 | 122 | self.__op_id_to_name[val] = name
|
112 | 123 |
|
113 |
| - def add(self, node: torch.fx.Node, tosa_op: Any, tosa_op_id: int) -> None: |
114 |
| - tosa_debug_info = TosaDebugSchema( |
115 |
| - node_name=str(tosa_op), |
116 |
| - operator_name=self.__op_id_to_name[tosa_op_id], |
117 |
| - operator_id=tosa_op_id, |
118 |
| - ) |
| 124 | + def add(self, node: torch.fx.Node, tosa_op: Any, tosa_op_id: int) -> DebugSchema: |
| 125 | + tosa_debug_info = None |
| 126 | + |
| 127 | + # If the debug data is being embedded into the TOSA flatbuffer |
| 128 | + # do not collect TOSADebugSchema data, it's redundent |
| 129 | + if self.mode != ArmCompileSpecBuilder.DebugMode.TOSA: |
| 130 | + tosa_debug_info = TosaDebugSchema( |
| 131 | + node_name=str(tosa_op), |
| 132 | + operator_name=self.__op_id_to_name[tosa_op_id], |
| 133 | + operator_id=tosa_op_id, |
| 134 | + ) |
119 | 135 |
|
120 | 136 | aten_debug_info = ATenDebugSchema.from_node(node)
|
121 | 137 | torch_debug_info = TorchDebugSchema.from_node(node)
|
122 | 138 |
|
123 |
| - self._debug_events.append( |
124 |
| - DebugSchema( |
125 |
| - event_id=len(self._debug_events), |
126 |
| - aten_info=aten_debug_info, |
127 |
| - tosa_info=tosa_debug_info, |
128 |
| - torch_info=torch_debug_info, |
129 |
| - ) |
| 139 | + debug_info = DebugSchema( |
| 140 | + event_id=len(self._debug_events), |
| 141 | + aten_info=aten_debug_info, |
| 142 | + tosa_info=tosa_debug_info, |
| 143 | + torch_info=torch_debug_info, |
130 | 144 | )
|
| 145 | + self._debug_events.append(debug_info) |
| 146 | + |
| 147 | + return debug_info |
131 | 148 |
|
132 | 149 | def serialize(self) -> str:
|
133 |
| - return json.dumps([asdict(event) for event in self._debug_events], indent=4) |
| 150 | + return json.dumps([event.to_dict() for event in self._debug_events], indent=4) |
0 commit comments