File tree Expand file tree Collapse file tree 2 files changed +12
-1
lines changed
Expand file tree Collapse file tree 2 files changed +12
-1
lines changed Original file line number Diff line number Diff line change @@ -5,7 +5,7 @@ MODEL_PATH_IN_SAMPLES=/timm/resnet18
55read -r -d ' ' json_str << 'EOF '
66{
77 "output_dir": "/tmp/naive_decompose_workspace",
8- "split_positions": [0 , 32],
8+ "split_positions": [8 , 32],
99 "group_head_and_tail": true
1010}
1111EOF
Original file line number Diff line number Diff line change @@ -72,6 +72,12 @@ def get_end_node_idx(range_idx):
7272 return i + 1
7373 raise NotImplementedError ("Dead code." )
7474
75+ def print_submodule_call (prompt , gm ):
76+ submodule_call_stmts = [
77+ stmt for stmt in gm .code .split ("\n " ) if "self.extracted_submodule" in stmt
78+ ]
79+ print (f"{ prompt } " , submodule_call_stmts )
80+
7581 for range_idx in range (len (range_idx2submodule_body_nodes )):
7682 (
7783 submodule_input_nodes ,
@@ -144,9 +150,14 @@ def get_output_nodes(range_idx):
144150 # Erase old nodes
145151 for node in reversed (get_body_nodes (range_idx )):
146152 original_gm .graph .erase_node (node )
153+ # print_submodule_call("(fx) after Erase old nodes", original_gm)
154+
155+ # print_submodule_call("(fx) before recompile", original_gm)
147156
148157 original_gm .recompile ()
149158
159+ # print_submodule_call("(fx) after recompile", original_gm)
160+
150161 return original_gm
151162
152163
You can’t perform that action at this time.
0 commit comments