diff --git a/graph_net/torch/fx_graph_parse_util.py b/graph_net/torch/fx_graph_parse_util.py index e18f33e3f..d4741a750 100755 --- a/graph_net/torch/fx_graph_parse_util.py +++ b/graph_net/torch/fx_graph_parse_util.py @@ -174,18 +174,14 @@ def get_diff_input_names(): if name not in placeholder_names ] - if len(inputs) == len(traced_sample_inputs) + 1: - diff_input_names = get_diff_input_names() - assert len(diff_input_names) == 1, f"{diff_input_names=}" - pos, name = diff_input_names[0] - for i, node in enumerate(traced_module.graph.nodes): - if i < pos: - assert node.op == "placeholder" - elif i == pos: - with traced_module.graph.inserting_before(node): - traced_module.graph.placeholder(name) - else: - break + diff_input_names = get_diff_input_names() + if len(diff_input_names) > 0: + first_node = next(iter(traced_module.graph.nodes)) + + with traced_module.graph.inserting_before(first_node): + for _, name in diff_input_names: + traced_module.graph.placeholder(name) + traced_module.recompile() def get_zip_filter_names():