@@ -72,16 +72,11 @@ def get_end_node_idx(range_idx):
7272 return i + 1
7373 raise NotImplementedError ("Dead code." )
7474
75- def print_submodule_call (prompt , gm ):
76- submodule_call_stmts = [
77- stmt for stmt in gm .code .split ("\n " ) if "self.extracted_submodule" in stmt
78- ]
79- print (f"{ prompt } " , submodule_call_stmts )
80-
8175 for range_idx in range (len (range_idx2submodule_body_nodes )):
8276 (
8377 submodule_input_nodes ,
8478 submodule_output_nodes ,
79+ identity_nodes ,
8580 ) = _get_submodule_inputs_and_outputs (
8681 original_gm = original_gm ,
8782 start_node_idx = get_start_node_idx (range_idx ),
@@ -141,20 +136,17 @@ def get_output_nodes(range_idx):
141136 prev_node = new_output_node
142137
143138 # Replace all use of outputs
139+ identity_node_set = set (identity_nodes )
144140 for original_output in get_output_nodes (range_idx ):
145- original_output .replace_all_uses_with (node_map [original_output ])
141+ if original_output not in identity_node_set :
142+ original_output .replace_all_uses_with (node_map [original_output ])
146143
147144 # Erase old nodes
148145 for node in reversed (get_body_nodes (range_idx )):
149146 original_gm .graph .erase_node (node )
150- # print_submodule_call("(fx) after Erase old nodes", original_gm)
151-
152- # print_submodule_call("(fx) before recompile", original_gm)
153147
154148 original_gm .recompile ()
155149
156- # print_submodule_call("(fx) after recompile", original_gm)
157-
158150 return original_gm
159151
160152
@@ -215,33 +207,38 @@ def get_related_node(node):
215207 for related_node in get_related_node (node ):
216208 count_ctx .node2after_output [related_node ] += 1
217209
218- if chain_style :
219- input_nodes = [
220- node
221- for node in node_list
222- if (count_ctx .node2before_input [node ] > 0 )
223- if (count_ctx .node2body [node ] > 0 or count_ctx .node2after_output [node ] > 0 )
224- ]
225- input_nodes_set = set (input_nodes )
226- output_nodes = [
227- node
228- for node in node_list
229- if (count_ctx .node2before_input [node ] > 0 or count_ctx .node2body [node ] > 0 )
230- if (count_ctx .node2after_output [node ] > 0 )
231- ]
210+ input_nodes = [
211+ node
212+ for node in node_list
213+ if count_ctx .node2before_input [node ] > 0
214+ if count_ctx .node2body [node ] > 0
215+ ]
216+ output_nodes = [
217+ node
218+ for node in node_list
219+ if not (count_ctx .node2before_input [node ] > 0 )
220+ if count_ctx .node2body [node ] > 0
221+ if count_ctx .node2after_output [node ] > 0
222+ ]
223+ if not chain_style :
224+ identity_nodes = []
232225 else :
233- input_nodes = [
226+ identity_nodes = [
234227 node
235228 for node in node_list
236229 if count_ctx .node2before_input [node ] > 0
237- if count_ctx .node2body [node ] > 0
230+ if count_ctx .node2body [node ] == 0
231+ if count_ctx .node2after_output [node ] > 0
232+ ][:1 ]
233+ input_nodes_set = set (input_nodes )
234+ input_nodes = [
235+ * input_nodes ,
236+ * [node for node in identity_nodes if node not in input_nodes_set ],
238237 ]
238+ output_nodes_set = set (output_nodes )
239239 output_nodes = [
240- node
241- for node in node_list
242- if not (count_ctx .node2before_input [node ] > 0 )
243- if count_ctx .node2body [node ] > 0
244- if count_ctx .node2after_output [node ] > 0
240+ * output_nodes ,
241+ * [node for node in identity_nodes if node not in output_nodes_set ],
245242 ]
246243
247- return input_nodes , output_nodes
244+ return input_nodes , output_nodes , identity_nodes
0 commit comments