66
77
88def convert_to_submodules_graph (
9- original_gm : torch .fx .GraphModule ,
9+ gm : torch .fx .GraphModule ,
1010 split_positions : list [int ],
1111 submodule_hook = None ,
1212 submodule_name_prefix = "extracted_submodule" ,
1313 chain_style = False ,
1414 group_head_and_tail = True ,
1515):
1616 """
17- chain_style=True: decompose original_gm into g0 * g1 * g2 * g3
17+ chain_style=True: decompose gm into g0 * g1 * g2 * g3
1818 """
19- original_gm = copy .deepcopy (original_gm )
19+ gm = copy .deepcopy (gm )
2020 num_placeholders = len (
21- [node for node in original_gm .graph .nodes if node .op == "placeholder" ]
21+ [node for node in gm .graph .nodes if node .op == "placeholder" ]
2222 )
2323 submodules_body_nodes = [
2424 node
25- for node in original_gm .graph .nodes
25+ for node in gm .graph .nodes
2626 if node .op
2727 not in {
2828 "placeholder" ,
@@ -54,20 +54,20 @@ def get_name2sub_submodule():
5454 )
5555 return {
5656 name : module
57- for name , module in original_gm .named_modules ()
57+ for name , module in gm .named_modules ()
5858 if name in used_module_names
5959 }
6060
6161 def get_start_node_idx (range_idx ):
6262 start_node = get_body_nodes (range_idx )[0 ]
63- for i , node in enumerate (original_gm .graph .nodes ):
63+ for i , node in enumerate (gm .graph .nodes ):
6464 if node == start_node :
6565 return i
6666 raise NotImplementedError ("Dead code." )
6767
6868 def get_end_node_idx (range_idx ):
6969 last_node = get_body_nodes (range_idx )[- 1 ]
70- for i , node in enumerate (original_gm .graph .nodes ):
70+ for i , node in enumerate (gm .graph .nodes ):
7171 if node == last_node :
7272 return i + 1
7373 raise NotImplementedError ("Dead code." )
@@ -78,23 +78,30 @@ def print_submodule_call(prompt, gm):
7878 ]
7979 print (f"{ prompt } " , submodule_call_stmts )
8080
81+ new_node2original_node = {}
82+ for node in gm .graph .nodes :
83+ new_node2original_node [node ] = node
84+
85+ def sort_key (node ):
86+ return new_node2original_node [node ].name
87+
8188 for range_idx in range (len (range_idx2submodule_body_nodes )):
8289 (
8390 submodule_input_nodes ,
8491 submodule_output_nodes ,
8592 identity_nodes ,
8693 ) = _get_submodule_inputs_and_outputs (
87- original_gm = original_gm ,
94+ gm = gm ,
8895 start_node_idx = get_start_node_idx (range_idx ),
8996 end_node_idx = get_end_node_idx (range_idx ),
9097 chain_style = chain_style ,
9198 )
9299
93100 def get_input_nodes (range_idx ):
94- return submodule_input_nodes
101+ return sorted ( submodule_input_nodes , key = sort_key )
95102
96103 def get_output_nodes (range_idx ):
97- return submodule_output_nodes
104+ return sorted ( submodule_output_nodes , key = sort_key )
98105
99106 submodule_name = (
100107 f"{ submodule_name_prefix } _{ range_idx } "
@@ -107,7 +114,8 @@ def get_output_nodes(range_idx):
107114
108115 # Add placeholder nodes for inputs
109116 for original_node in get_input_nodes (range_idx ):
110- new_node = new_graph .placeholder (original_node .name )
117+ name = new_node2original_node [original_node ].name
118+ new_node = new_graph .placeholder (name )
111119 node_map [original_node ] = new_node
112120
113121 # Copy body nodes
@@ -116,9 +124,9 @@ def get_output_nodes(range_idx):
116124 node_map [original_node ] = new_node
117125
118126 # Add output nodes
119- output_args = []
120- for original_node in get_output_nodes (range_idx ):
121- output_args . append ( node_map [ original_node ])
127+ output_args = [
128+ node_map [ original_node ] for original_node in get_output_nodes (range_idx )
129+ ]
122130 new_graph .output (tuple (output_args ))
123131
124132 # Create the new GraphModule
@@ -127,15 +135,15 @@ def get_output_nodes(range_idx):
127135 if submodule_hook is not None :
128136 new_sub_module = submodule_hook (new_sub_module , range_idx )
129137 # Replace with submodule node
130- original_gm .add_submodule (submodule_name , new_sub_module )
131- with original_gm .graph .inserting_after (get_body_nodes (range_idx )[- 1 ]):
132- submodule_node = original_gm .graph .call_module (
138+ gm .add_submodule (submodule_name , new_sub_module )
139+ with gm .graph .inserting_after (get_body_nodes (range_idx )[- 1 ]):
140+ submodule_node = gm .graph .call_module (
133141 submodule_name , tuple (get_input_nodes (range_idx ))
134142 )
135143 prev_node = submodule_node
136144 for idx , original_output in enumerate (get_output_nodes (range_idx )):
137- with original_gm .graph .inserting_after (prev_node ):
138- new_output_node = original_gm .graph .call_function (
145+ with gm .graph .inserting_after (prev_node ):
146+ new_output_node = gm .graph .call_function (
139147 operator .getitem , (submodule_node , idx )
140148 )
141149 node_map [original_output ] = new_output_node
@@ -146,31 +154,34 @@ def get_output_nodes(range_idx):
146154 for original_output in get_output_nodes (range_idx ):
147155 if original_output not in identity_node_set :
148156 original_output .replace_all_uses_with (node_map [original_output ])
157+ new_node2original_node [
158+ node_map [original_output ]
159+ ] = new_node2original_node [original_output ]
149160
150161 # Erase old nodes
151162 for node in reversed (get_body_nodes (range_idx )):
152- original_gm .graph .erase_node (node )
153- # print_submodule_call("(fx) after Erase old nodes", original_gm )
163+ gm .graph .erase_node (node )
164+ # print_submodule_call("(fx) after Erase old nodes", gm )
154165
155- # print_submodule_call("(fx) before recompile", original_gm )
166+ # print_submodule_call("(fx) before recompile", gm )
156167
157- original_gm .recompile ()
168+ gm .recompile ()
158169
159- # print_submodule_call("(fx) after recompile", original_gm )
170+ # print_submodule_call("(fx) after recompile", gm )
160171
161- return original_gm
172+ return gm
162173
163174
164175def fold_range_to_submodule (
165- original_gm : torch .fx .GraphModule ,
176+ gm : torch .fx .GraphModule ,
166177 start_node_idx : int ,
167178 end_node_idx : int ,
168179 submodule_hook = None ,
169180 submodule_name = "extracted_submodule" ,
170181 group_head_and_tail = True ,
171182):
172183 return convert_to_submodules_graph (
173- original_gm ,
184+ gm ,
174185 split_positions = [start_node_idx , end_node_idx ],
175186 submodule_hook = submodule_hook ,
176187 submodule_name_prefix = submodule_name ,
@@ -186,7 +197,7 @@ class NodeProducedOrConsumedCountCtx:
186197
187198
188199def _get_submodule_inputs_and_outputs (
189- original_gm : torch .fx .GraphModule ,
200+ gm : torch .fx .GraphModule ,
190201 start_node_idx : int ,
191202 end_node_idx : int ,
192203 chain_style = False ,
@@ -196,7 +207,7 @@ def _get_submodule_inputs_and_outputs(
196207 defaultdict (int ),
197208 defaultdict (int ),
198209 )
199- node_list = list (original_gm .graph .nodes )
210+ node_list = list (gm .graph .nodes )
200211
201212 def get_related_node (node ):
202213 for arg in node .args :
@@ -240,7 +251,7 @@ def get_related_node(node):
240251 if count_ctx .node2before_input [node ] > 0
241252 if count_ctx .node2body [node ] == 0
242253 if count_ctx .node2after_output [node ] > 0
243- ][: 1 ]
254+ ]
244255 input_nodes_set = set (input_nodes )
245256 input_nodes = [
246257 * input_nodes ,
@@ -251,5 +262,4 @@ def get_related_node(node):
251262 * output_nodes ,
252263 * [node for node in identity_nodes if node not in output_nodes_set ],
253264 ]
254-
255265 return input_nodes , output_nodes , identity_nodes
0 commit comments