Skip to content

Commit 538a90d

Browse files
committed
merge develop
2 parents b06e8ec + 06b8dc2 commit 538a90d

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

graph_net/test/naive_graph_decomposer_test.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ MODEL_PATH_IN_SAMPLES=/timm/resnet18
55
read -r -d '' json_str <<'EOF'
66
{
77
"output_dir": "/tmp/naive_decompose_workspace",
8-
"split_positions": [0, 32],
8+
"split_positions": [8, 32],
99
"group_head_and_tail": true
1010
}
1111
EOF

graph_net/torch/decompose_util.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,12 @@ def get_end_node_idx(range_idx):
7272
return i + 1
7373
raise NotImplementedError("Dead code.")
7474

75+
def print_submodule_call(prompt, gm):
76+
submodule_call_stmts = [
77+
stmt for stmt in gm.code.split("\n") if "self.extracted_submodule" in stmt
78+
]
79+
print(f"{prompt} ", submodule_call_stmts)
80+
7581
for range_idx in range(len(range_idx2submodule_body_nodes)):
7682
(
7783
submodule_input_nodes,
@@ -144,9 +150,14 @@ def get_output_nodes(range_idx):
144150
# Erase old nodes
145151
for node in reversed(get_body_nodes(range_idx)):
146152
original_gm.graph.erase_node(node)
153+
# print_submodule_call("(fx) after Erase old nodes", original_gm)
154+
155+
# print_submodule_call("(fx) before recompile", original_gm)
147156

148157
original_gm.recompile()
149158

159+
# print_submodule_call("(fx) after recompile", original_gm)
160+
150161
return original_gm
151162

152163

0 commit comments

Comments
 (0)