@@ -85,7 +85,8 @@ def print_submodule_call(prompt, gm):
8585 def sort_key (node ):
8686 return new_node2original_node [node ].name
8787
88- for range_idx in range (len (range_idx2submodule_body_nodes )):
88+ num_subgraphs = len (range_idx2submodule_body_nodes )
89+ for range_idx in range (num_subgraphs ):
8990 (
9091 submodule_input_nodes ,
9192 submodule_output_nodes ,
@@ -96,6 +97,7 @@ def sort_key(node):
9697 end_node_idx = get_end_node_idx (range_idx ),
9798 chain_style = chain_style ,
9899 )
100+ identity_node_set = set (identity_nodes )
99101
100102 def get_input_nodes (range_idx ):
101103 return sorted (submodule_input_nodes , key = sort_key )
@@ -153,13 +155,13 @@ def get_output_nodes(range_idx):
153155 prev_node = new_output_node
154156
155157 # Replace all use of outputs
156- identity_node_set = set (identity_nodes )
157158 for original_output in get_output_nodes (range_idx ):
158- if original_output not in identity_node_set :
159- original_output .replace_all_uses_with (node_map [original_output ])
160- new_node2original_node [
161- node_map [original_output ]
162- ] = new_node2original_node [original_output ]
159+ if original_output in identity_node_set :
160+ continue
161+ original_output .replace_all_uses_with (node_map [original_output ])
162+ new_node2original_node [node_map [original_output ]] = new_node2original_node [
163+ original_output
164+ ]
163165
164166 # Erase old nodes
165167 for node in reversed (get_body_nodes (range_idx )):
@@ -215,12 +217,18 @@ def _get_submodule_inputs_and_outputs(
215217 return minimal_input_nodes , minimal_output_nodes , []
216218 else :
217219 node_list = list (gm .graph .nodes )
218- input_nodes , _ = _get_minimal_submodule_inputs_and_outputs (
219- gm = gm , start_node_idx = start_node_idx , end_node_idx = len (node_list )
220- )
221- output_nodes , _ = _get_minimal_submodule_inputs_and_outputs (
222- gm = gm , start_node_idx = end_node_idx , end_node_idx = len (node_list )
223- )
220+ if _is_node_idx_out_of_range (gm , start_node_idx ):
221+ input_nodes = list (_get_return_nodes (gm ))
222+ else :
223+ input_nodes , _ = _get_minimal_submodule_inputs_and_outputs (
224+ gm = gm , start_node_idx = start_node_idx , end_node_idx = len (node_list )
225+ )
226+ if _is_node_idx_out_of_range (gm , end_node_idx ):
227+ output_nodes = list (_get_return_nodes (gm ))
228+ else :
229+ output_nodes , _ = _get_minimal_submodule_inputs_and_outputs (
230+ gm = gm , start_node_idx = end_node_idx , end_node_idx = len (node_list )
231+ )
224232 identity_nodes_set = set (input_nodes ) & set (output_nodes )
225233 identity_nodes = [node for node in input_nodes if node in identity_nodes_set ]
226234 return input_nodes , output_nodes , identity_nodes
@@ -275,25 +283,19 @@ def get_args_node_and_self_node(node):
275283 for related_node in get_args_node_and_self_node (node ):
276284 count_ctx .node2after_output [related_node ] += 1
277285
278- if _is_node_idx_out_of_range (gm , start_node_idx ):
279- input_nodes = list (_get_return_nodes (gm ))
280- else :
281- input_nodes = [
282- node
283- for node in node_list
284- if count_ctx .node2before_input [node ] > 0
285- if count_ctx .node2body [node ] > 0
286- ]
287- if _is_node_idx_out_of_range (gm , end_node_idx ):
288- output_nodes = list (_get_return_nodes (gm ))
289- else :
290- output_nodes = [
291- node
292- for node in node_list
293- if not (count_ctx .node2before_input [node ] > 0 )
294- if count_ctx .node2body [node ] > 0
295- if count_ctx .node2after_output [node ] > 0
296- ]
286+ input_nodes = [
287+ node
288+ for node in node_list
289+ if count_ctx .node2before_input [node ] > 0
290+ if count_ctx .node2body [node ] > 0
291+ ]
292+ output_nodes = [
293+ node
294+ for node in node_list
295+ if not (count_ctx .node2before_input [node ] > 0 )
296+ if count_ctx .node2body [node ] > 0
297+ if count_ctx .node2after_output [node ] > 0
298+ ]
297299 return input_nodes , output_nodes
298300
299301
0 commit comments