@@ -82,6 +82,7 @@ def print_submodule_call(prompt, gm):
8282 (
8383 submodule_input_nodes ,
8484 submodule_output_nodes ,
85+ identity_nodes ,
8586 ) = _get_submodule_inputs_and_outputs (
8687 original_gm = original_gm ,
8788 start_node_idx = get_start_node_idx (range_idx ),
@@ -141,8 +142,10 @@ def get_output_nodes(range_idx):
141142 prev_node = new_output_node
142143
143144 # Replace all use of outputs
145+ identity_node_set = set (identity_nodes )
144146 for original_output in get_output_nodes (range_idx ):
145- original_output .replace_all_uses_with (node_map [original_output ])
147+ if original_output not in identity_node_set :
148+ original_output .replace_all_uses_with (node_map [original_output ])
146149
147150 # Erase old nodes
148151 for node in reversed (get_body_nodes (range_idx )):
@@ -215,33 +218,38 @@ def get_related_node(node):
215218 for related_node in get_related_node (node ):
216219 count_ctx .node2after_output [related_node ] += 1
217220
218- if chain_style :
219- input_nodes = [
220- node
221- for node in node_list
222- if (count_ctx .node2before_input [node ] > 0 )
223- if (count_ctx .node2body [node ] > 0 or count_ctx .node2after_output [node ] > 0 )
224- ]
225- input_nodes_set = set (input_nodes )
226- output_nodes = [
227- node
228- for node in node_list
229- if (count_ctx .node2before_input [node ] > 0 or count_ctx .node2body [node ] > 0 )
230- if (count_ctx .node2after_output [node ] > 0 )
231- ]
221+ input_nodes = [
222+ node
223+ for node in node_list
224+ if count_ctx .node2before_input [node ] > 0
225+ if count_ctx .node2body [node ] > 0
226+ ]
227+ output_nodes = [
228+ node
229+ for node in node_list
230+ if not (count_ctx .node2before_input [node ] > 0 )
231+ if count_ctx .node2body [node ] > 0
232+ if count_ctx .node2after_output [node ] > 0
233+ ]
234+ if not chain_style :
235+ identity_nodes = []
232236 else :
233- input_nodes = [
237+ identity_nodes = [
234238 node
235239 for node in node_list
236240 if count_ctx .node2before_input [node ] > 0
237- if count_ctx .node2body [node ] > 0
241+ if count_ctx .node2body [node ] == 0
242+ if count_ctx .node2after_output [node ] > 0
243+ ][:1 ]
244+ input_nodes_set = set (input_nodes )
245+ input_nodes = [
246+ * input_nodes ,
247+ * [node for node in identity_nodes if node not in input_nodes_set ],
238248 ]
249+ output_nodes_set = set (output_nodes )
239250 output_nodes = [
240- node
241- for node in node_list
242- if not (count_ctx .node2before_input [node ] > 0 )
243- if count_ctx .node2body [node ] > 0
244- if count_ctx .node2after_output [node ] > 0
251+ * output_nodes ,
252+ * [node for node in identity_nodes if node not in output_nodes_set ],
245253 ]
246254
247- return input_nodes , output_nodes
255+ return input_nodes , output_nodes , identity_nodes
0 commit comments