|
6 | 6 | # pyre-unsafe
|
7 | 7 |
|
8 | 8 | import logging
|
9 |
| -import os |
10 |
| -from typing import Any, Optional |
| 9 | +from typing import Any |
11 | 10 |
|
12 | 11 | import numpy as np
|
13 | 12 | import serializer.tosa_serializer as ts # type: ignore
|
|
20 | 19 |
|
21 | 20 | from executorch.backends.arm.tosa_specification import TosaSpecification
|
22 | 21 | from executorch.exir.dialects._ops import ops as exir_ops
|
23 |
| -from executorch.exir.print_program import inspect_node |
24 | 22 |
|
25 | 23 | from torch._subclasses.fake_tensor import FakeTensor
|
26 | 24 | from torch.fx import Node
|
27 | 25 |
|
28 | 26 | logger = logging.getLogger(__name__)
|
29 | 27 |
|
30 | 28 |
|
31 |
| -def dbg_node(node: torch.fx.Node, graph_module: torch.fx.GraphModule): |
32 |
| - # Debug output of node information |
33 |
| - logger.info(get_node_debug_info(node, graph_module)) |
34 |
| - |
35 |
| - |
36 |
| -def get_node_debug_info( |
37 |
| - node: torch.fx.Node, graph_module: torch.fx.GraphModule | None = None |
38 |
| -) -> str: |
39 |
| - output = ( |
40 |
| - f" {inspect_node(graph=graph_module.graph, node=node)}\n" |
41 |
| - if graph_module |
42 |
| - else "" |
43 |
| - "-- NODE DEBUG INFO --\n" |
44 |
| - f" Op is {node.op}\n" |
45 |
| - f" Name is {node.name}\n" |
46 |
| - f" Node target is {node.target}\n" |
47 |
| - f" Node args is {node.args}\n" |
48 |
| - f" Node kwargs is {node.kwargs}\n" |
49 |
| - f" Node users is {node.users}\n" |
50 |
| - " Node.meta = \n" |
51 |
| - ) |
52 |
| - for k, v in node.meta.items(): |
53 |
| - if k == "stack_trace": |
54 |
| - matches = v.split("\n") |
55 |
| - output += " 'stack_trace =\n" |
56 |
| - for m in matches: |
57 |
| - output += f" {m}\n" |
58 |
| - else: |
59 |
| - output += f" '{k}' = {v}\n" |
60 |
| - |
61 |
| - if isinstance(v, list): |
62 |
| - for i in v: |
63 |
| - output += f" {i}\n" |
64 |
| - return output |
65 |
| - |
66 |
| - |
67 |
| -# Output TOSA flatbuffer and test harness file |
68 |
| -def dbg_tosa_dump(tosa_graph: ts.TosaSerializer, path: str, suffix: str = ""): |
69 |
| - filename = f"output{suffix}.tosa" |
70 |
| - |
71 |
| - logger.info(f"Emitting debug output to: {path=}, {suffix=}") |
72 |
| - |
73 |
| - os.makedirs(path, exist_ok=True) |
74 |
| - |
75 |
| - fb = tosa_graph.serialize() |
76 |
| - js = tosa_graph.writeJson(filename) |
77 |
| - |
78 |
| - filepath_tosa_fb = os.path.join(path, filename) |
79 |
| - with open(filepath_tosa_fb, "wb") as f: |
80 |
| - f.write(fb) |
81 |
| - assert os.path.exists(filepath_tosa_fb), "Failed to write TOSA flatbuffer" |
82 |
| - |
83 |
| - filepath_desc_json = os.path.join(path, f"desc{suffix}.json") |
84 |
| - with open(filepath_desc_json, "w") as f: |
85 |
| - f.write(js) |
86 |
| - assert os.path.exists(filepath_desc_json), "Failed to write TOSA JSON" |
87 |
| - |
88 |
| - |
89 |
| -def dbg_fail( |
90 |
| - node, |
91 |
| - graph_module, |
92 |
| - tosa_graph: Optional[ts.TosaSerializer] = None, |
93 |
| - path: Optional[str] = None, |
94 |
| -): |
95 |
| - logger.warning("Internal error due to poorly handled node:") |
96 |
| - if tosa_graph is not None and path is not None: |
97 |
| - dbg_tosa_dump(tosa_graph, path) |
98 |
| - logger.warning(f"Debug output captured in '{path}'.") |
99 |
| - dbg_node(node, graph_module) |
100 |
| - |
101 |
| - |
102 | 29 | def getNodeArgs(node: Node, tosa_spec: TosaSpecification) -> list[TosaArg]:
|
103 | 30 | try:
|
104 | 31 | return [TosaArg(arg, tosa_spec) for arg in node.args]
|
|
0 commit comments