66
77
88def convert_to_submodules_graph (
9- original_gm : torch .fx .GraphModule ,
9+ gm : torch .fx .GraphModule ,
1010 split_positions : list [int ],
1111 submodule_hook = None ,
1212 submodule_name_prefix = "extracted_submodule" ,
1313 chain_style = False ,
1414 group_head_and_tail = True ,
1515):
1616 """
17- chain_style=True: decompose original_gm into g0 * g1 * g2 * g3
17+ chain_style=True: decompose gm into g0 * g1 * g2 * g3
1818 """
19- original_gm = copy .deepcopy (original_gm )
19+ gm = copy .deepcopy (gm )
2020 num_placeholders = len (
21- [node for node in original_gm .graph .nodes if node .op == "placeholder" ]
21+ [node for node in gm .graph .nodes if node .op == "placeholder" ]
2222 )
2323 submodules_body_nodes = [
2424 node
25- for node in original_gm .graph .nodes
25+ for node in gm .graph .nodes
2626 if node .op
2727 not in {
2828 "placeholder" ,
@@ -37,13 +37,16 @@ def convert_to_submodules_graph(
3737 split_positions = [
3838 max (0 , min (pos , len (submodules_body_nodes ))) for pos in split_positions
3939 ]
40- range_idx2submodule_body_nodes = [
41- submodules_body_nodes [ start : end ]
40+ range_idx2range = [
41+ ( start , end )
4242 for i in range (len (split_positions ) - 1 )
4343 for start in [split_positions [i ]]
4444 for end in [split_positions [i + 1 ]]
4545 if end > start
4646 ]
47+ range_idx2submodule_body_nodes = [
48+ submodules_body_nodes [start :end ] for start , end in range_idx2range
49+ ]
4750
4851 def get_body_nodes (range_idx ):
4952 return range_idx2submodule_body_nodes [range_idx ]
@@ -54,20 +57,20 @@ def get_name2sub_submodule():
5457 )
5558 return {
5659 name : module
57- for name , module in original_gm .named_modules ()
60+ for name , module in gm .named_modules ()
5861 if name in used_module_names
5962 }
6063
6164 def get_start_node_idx (range_idx ):
6265 start_node = get_body_nodes (range_idx )[0 ]
63- for i , node in enumerate (original_gm .graph .nodes ):
66+ for i , node in enumerate (gm .graph .nodes ):
6467 if node == start_node :
6568 return i
6669 raise NotImplementedError ("Dead code." )
6770
6871 def get_end_node_idx (range_idx ):
6972 last_node = get_body_nodes (range_idx )[- 1 ]
70- for i , node in enumerate (original_gm .graph .nodes ):
73+ for i , node in enumerate (gm .graph .nodes ):
7174 if node == last_node :
7275 return i + 1
7376 raise NotImplementedError ("Dead code." )
@@ -78,23 +81,33 @@ def print_submodule_call(prompt, gm):
7881 ]
7982 print (f"{ prompt } " , submodule_call_stmts )
8083
84+ new_node2original_node = {}
85+ for node in gm .graph .nodes :
86+ new_node2original_node [node ] = node
87+
88+ def sort_key (node ):
89+ return new_node2original_node [node ].name
90+
8191 for range_idx in range (len (range_idx2submodule_body_nodes )):
8292 (
8393 submodule_input_nodes ,
8494 submodule_output_nodes ,
8595 identity_nodes ,
8696 ) = _get_submodule_inputs_and_outputs (
87- original_gm = original_gm ,
97+ gm = gm ,
8898 start_node_idx = get_start_node_idx (range_idx ),
8999 end_node_idx = get_end_node_idx (range_idx ),
90100 chain_style = chain_style ,
91101 )
92102
93103 def get_input_nodes (range_idx ):
94- return submodule_input_nodes
104+ return sorted ( submodule_input_nodes , key = sort_key )
95105
96106 def get_output_nodes (range_idx ):
97- return submodule_output_nodes
107+ end = range_idx2range [range_idx ][1 ]
108+ if end >= len (submodules_body_nodes ):
109+ return submodule_output_nodes
110+ return sorted (submodule_output_nodes , key = sort_key )
98111
99112 submodule_name = (
100113 f"{ submodule_name_prefix } _{ range_idx } "
@@ -107,7 +120,8 @@ def get_output_nodes(range_idx):
107120
108121 # Add placeholder nodes for inputs
109122 for original_node in get_input_nodes (range_idx ):
110- new_node = new_graph .placeholder (original_node .name )
123+ name = new_node2original_node [original_node ].name
124+ new_node = new_graph .placeholder (name )
111125 node_map [original_node ] = new_node
112126
113127 # Copy body nodes
@@ -116,9 +130,9 @@ def get_output_nodes(range_idx):
116130 node_map [original_node ] = new_node
117131
118132 # Add output nodes
119- output_args = []
120- for original_node in get_output_nodes (range_idx ):
121- output_args . append ( node_map [ original_node ])
133+ output_args = [
134+ node_map [ original_node ] for original_node in get_output_nodes (range_idx )
135+ ]
122136 new_graph .output (tuple (output_args ))
123137
124138 # Create the new GraphModule
@@ -127,15 +141,15 @@ def get_output_nodes(range_idx):
127141 if submodule_hook is not None :
128142 new_sub_module = submodule_hook (new_sub_module , range_idx )
129143 # Replace with submodule node
130- original_gm .add_submodule (submodule_name , new_sub_module )
131- with original_gm .graph .inserting_after (get_body_nodes (range_idx )[- 1 ]):
132- submodule_node = original_gm .graph .call_module (
144+ gm .add_submodule (submodule_name , new_sub_module )
145+ with gm .graph .inserting_after (get_body_nodes (range_idx )[- 1 ]):
146+ submodule_node = gm .graph .call_module (
133147 submodule_name , tuple (get_input_nodes (range_idx ))
134148 )
135149 prev_node = submodule_node
136150 for idx , original_output in enumerate (get_output_nodes (range_idx )):
137- with original_gm .graph .inserting_after (prev_node ):
138- new_output_node = original_gm .graph .call_function (
151+ with gm .graph .inserting_after (prev_node ):
152+ new_output_node = gm .graph .call_function (
139153 operator .getitem , (submodule_node , idx )
140154 )
141155 node_map [original_output ] = new_output_node
@@ -146,31 +160,34 @@ def get_output_nodes(range_idx):
146160 for original_output in get_output_nodes (range_idx ):
147161 if original_output not in identity_node_set :
148162 original_output .replace_all_uses_with (node_map [original_output ])
163+ new_node2original_node [
164+ node_map [original_output ]
165+ ] = new_node2original_node [original_output ]
149166
150167 # Erase old nodes
151168 for node in reversed (get_body_nodes (range_idx )):
152- original_gm .graph .erase_node (node )
153- # print_submodule_call("(fx) after Erase old nodes", original_gm )
169+ gm .graph .erase_node (node )
170+ # print_submodule_call("(fx) after Erase old nodes", gm )
154171
155- # print_submodule_call("(fx) before recompile", original_gm )
172+ # print_submodule_call("(fx) before recompile", gm )
156173
157- original_gm .recompile ()
174+ gm .recompile ()
158175
159- # print_submodule_call("(fx) after recompile", original_gm )
176+ # print_submodule_call("(fx) after recompile", gm )
160177
161- return original_gm
178+ return gm
162179
163180
164181def fold_range_to_submodule (
165- original_gm : torch .fx .GraphModule ,
182+ gm : torch .fx .GraphModule ,
166183 start_node_idx : int ,
167184 end_node_idx : int ,
168185 submodule_hook = None ,
169186 submodule_name = "extracted_submodule" ,
170187 group_head_and_tail = True ,
171188):
172189 return convert_to_submodules_graph (
173- original_gm ,
190+ gm ,
174191 split_positions = [start_node_idx , end_node_idx ],
175192 submodule_hook = submodule_hook ,
176193 submodule_name_prefix = submodule_name ,
@@ -186,7 +203,7 @@ class NodeProducedOrConsumedCountCtx:
186203
187204
188205def _get_submodule_inputs_and_outputs (
189- original_gm : torch .fx .GraphModule ,
206+ gm : torch .fx .GraphModule ,
190207 start_node_idx : int ,
191208 end_node_idx : int ,
192209 chain_style = False ,
@@ -196,7 +213,7 @@ def _get_submodule_inputs_and_outputs(
196213 defaultdict (int ),
197214 defaultdict (int ),
198215 )
199- node_list = list (original_gm .graph .nodes )
216+ node_list = list (gm .graph .nodes )
200217
201218 def get_related_node (node ):
202219 for arg in node .args :
@@ -240,7 +257,7 @@ def get_related_node(node):
240257 if count_ctx .node2before_input [node ] > 0
241258 if count_ctx .node2body [node ] == 0
242259 if count_ctx .node2after_output [node ] > 0
243- ][: 1 ]
260+ ]
244261 input_nodes_set = set (input_nodes )
245262 input_nodes = [
246263 * input_nodes ,
@@ -251,5 +268,4 @@ def get_related_node(node):
251268 * output_nodes ,
252269 * [node for node in identity_nodes if node not in output_nodes_set ],
253270 ]
254-
255271 return input_nodes , output_nodes , identity_nodes
0 commit comments