Skip to content

Commit 100c16a

Browse files
committed
Update on "restucture debug handle"
Differential Revision: [D66622890](https://our.internmc.facebook.com/intern/diff/D66622890/) [ghstack-poisoned]
2 parents 7f0ed4b + ca4f428 commit 100c16a

File tree

2 files changed

+4
-5
lines changed

2 files changed

+4
-5
lines changed

exir/graph_module.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
# pyre-strict
88

99
from types import FunctionType as function
10-
from typing import Any, Callable, Dict, List, Tuple, Union
10+
from typing import Callable, Dict, List, Tuple, Union
1111

1212
import torch
1313

@@ -69,15 +69,13 @@ def get_control_flow_submodules(
6969

7070
return control_flow_submodules
7171

72-
# TODO(gasoonjia): remove this and leverage core pytorch bfs_trace_with_node_process after code freeze
72+
7373
def bfs_trace_with_node_process(
7474
gm: torch.fx.GraphModule, node_op: Callable[[torch.fx.Node], None]
7575
) -> None:
7676
"""Traverse the graph module and apply node_op to each node."""
7777

78-
assert isinstance(
79-
gm, torch.fx.GraphModule
80-
), f"Expected GraphModule, got {type(gm)}"
78+
assert isinstance(gm, torch.fx.GraphModule), f"Expected GraphModule, got {type(gm)}"
8179

8280
queue = [gm]
8381
while queue:

exir/passes/debug_handle_generator_pass.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from torch.fx import GraphModule
1111
from torch.fx.passes.infra.pass_base import PassResult
1212

13+
1314
class DebugHandleGeneratorPass(ExportPass):
1415
def call(self, graph_module: GraphModule) -> PassResult:
1516
"""Lower a quantized reference model (with reference quantized operator patterns)

0 commit comments

Comments
 (0)