|
6 | 6 | # pyre-unsafe |
7 | 7 |
|
8 | 8 | import logging |
9 | | -import os |
10 | | -from typing import Any, Optional |
| 9 | +from typing import Any |
11 | 10 |
|
12 | 11 | import numpy as np |
13 | 12 | import serializer.tosa_serializer as ts # type: ignore |
|
20 | 19 |
|
21 | 20 | from executorch.backends.arm.tosa_specification import TosaSpecification |
22 | 21 | from executorch.exir.dialects._ops import ops as exir_ops |
23 | | -from executorch.exir.print_program import inspect_node |
24 | 22 |
|
25 | 23 | from torch._subclasses.fake_tensor import FakeTensor |
26 | 24 | from torch.fx import Node |
27 | 25 |
|
28 | 26 | logger = logging.getLogger(__name__) |
29 | 27 |
|
30 | 28 |
|
31 | | -def dbg_node(node: torch.fx.Node, graph_module: torch.fx.GraphModule): |
32 | | - # Debug output of node information |
33 | | - logger.info(get_node_debug_info(node, graph_module)) |
34 | | - |
35 | | - |
36 | | -def get_node_debug_info( |
37 | | - node: torch.fx.Node, graph_module: torch.fx.GraphModule | None = None |
38 | | -) -> str: |
39 | | - output = ( |
40 | | - f" {inspect_node(graph=graph_module.graph, node=node)}\n" |
41 | | - if graph_module |
42 | | - else "" |
43 | | - "-- NODE DEBUG INFO --\n" |
44 | | - f" Op is {node.op}\n" |
45 | | - f" Name is {node.name}\n" |
46 | | - f" Node target is {node.target}\n" |
47 | | - f" Node args is {node.args}\n" |
48 | | - f" Node kwargs is {node.kwargs}\n" |
49 | | - f" Node users is {node.users}\n" |
50 | | - " Node.meta = \n" |
51 | | - ) |
52 | | - for k, v in node.meta.items(): |
53 | | - if k == "stack_trace": |
54 | | - matches = v.split("\n") |
55 | | - output += " 'stack_trace =\n" |
56 | | - for m in matches: |
57 | | - output += f" {m}\n" |
58 | | - else: |
59 | | - output += f" '{k}' = {v}\n" |
60 | | - |
61 | | - if isinstance(v, list): |
62 | | - for i in v: |
63 | | - output += f" {i}\n" |
64 | | - return output |
65 | | - |
66 | | - |
67 | | -# Output TOSA flatbuffer and test harness file |
68 | | -def dbg_tosa_dump(tosa_graph: ts.TosaSerializer, path: str, suffix: str = ""): |
69 | | - filename = f"output{suffix}.tosa" |
70 | | - |
71 | | - logger.info(f"Emitting debug output to: {path=}, {suffix=}") |
72 | | - |
73 | | - os.makedirs(path, exist_ok=True) |
74 | | - |
75 | | - fb = tosa_graph.serialize() |
76 | | - js = tosa_graph.writeJson(filename) |
77 | | - |
78 | | - filepath_tosa_fb = os.path.join(path, filename) |
79 | | - with open(filepath_tosa_fb, "wb") as f: |
80 | | - f.write(fb) |
81 | | - assert os.path.exists(filepath_tosa_fb), "Failed to write TOSA flatbuffer" |
82 | | - |
83 | | - filepath_desc_json = os.path.join(path, f"desc{suffix}.json") |
84 | | - with open(filepath_desc_json, "w") as f: |
85 | | - f.write(js) |
86 | | - assert os.path.exists(filepath_desc_json), "Failed to write TOSA JSON" |
87 | | - |
88 | | - |
89 | | -def dbg_fail( |
90 | | - node, |
91 | | - graph_module, |
92 | | - tosa_graph: Optional[ts.TosaSerializer] = None, |
93 | | - path: Optional[str] = None, |
94 | | -): |
95 | | - logger.warning("Internal error due to poorly handled node:") |
96 | | - if tosa_graph is not None and path is not None: |
97 | | - dbg_tosa_dump(tosa_graph, path) |
98 | | - logger.warning(f"Debug output captured in '{path}'.") |
99 | | - dbg_node(node, graph_module) |
100 | | - |
101 | | - |
102 | 29 | def getNodeArgs(node: Node, tosa_spec: TosaSpecification) -> list[TosaArg]: |
103 | 30 | try: |
104 | 31 | return [TosaArg(arg, tosa_spec) for arg in node.args] |
|
0 commit comments