diff --git a/graph_net/torch/fx_graph_parse_util.py b/graph_net/torch/fx_graph_parse_util.py index d4741a750..0f41f5eed 100755 --- a/graph_net/torch/fx_graph_parse_util.py +++ b/graph_net/torch/fx_graph_parse_util.py @@ -166,24 +166,6 @@ def handle_placeholder_name(pattern2replacement): handle_placeholder_name(pattern2replacement) - def get_diff_input_names(): - placeholder_names = set(get_input_names_from_placeholder()) - return [ - (i, name) - for i, name in enumerate(get_input_names_from_signature()) - if name not in placeholder_names - ] - - diff_input_names = get_diff_input_names() - if len(diff_input_names) > 0: - first_node = next(iter(traced_module.graph.nodes)) - - with traced_module.graph.inserting_before(first_node): - for _, name in diff_input_names: - traced_module.graph.placeholder(name) - - traced_module.recompile() - def get_zip_filter_names(): names_from_signature = get_input_names_from_signature() names_from_placeholder = get_input_names_from_placeholder() @@ -198,61 +180,44 @@ def get_zip_filter_names(): ) def handle_underscore_suffix_difference(): - zip_filter_names = get_zip_filter_names() - if not (len(zip_filter_names) > 0): - return - names = set( - name_in_placeholder - for _0, name_in_signature, name_in_placeholder in zip_filter_names - if (f"{name_in_signature}_" == name_in_placeholder) - ) - names_in_signature = set( - name_in_signature - for _0, name_in_signature, name_in_placeholder in zip_filter_names - ) - underscore_suffixed_names_in_signature = set( - f"{name_in_signature}_" - for _0, name_in_signature, name_in_placeholder in zip_filter_names - ) - if len(names_in_signature & underscore_suffixed_names_in_signature) == 0: - disordered_names = set( - name_in_placeholder - for _0, name_in_signature, name_in_placeholder in zip_filter_names - if f"{name_in_signature}_" != name_in_placeholder - if name_in_placeholder in underscore_suffixed_names_in_signature - ) - names = names | disordered_names - for node in traced_module.graph.nodes: - if not (node.op == "placeholder"): + ph_nodes = { + node.name: node + for node in traced_module.graph.nodes + if node.op == "placeholder" + } + sig_names = get_input_names_from_signature() + sig_names_set = set(sig_names) + for name in sig_names: + target_ph_name = f"{name}_" + if name in ph_nodes or target_ph_name not in ph_nodes: continue - if node.name not in names: + if target_ph_name in sig_names_set: continue - node.target = node.target[:-1] - node.name = node.name[:-1] + node = ph_nodes[target_ph_name] + node.target = node.name = name traced_module.recompile() handle_underscore_suffix_difference() - def handle_prefix_difference(): - zip_filter_names = get_zip_filter_names() - if not (len(zip_filter_names) > 0): - return - names = set( - name_in_placeholder - for _0, name_in_signature, name_in_placeholder in zip_filter_names - if (f"l_{name_in_signature[2:]}" == name_in_placeholder) - if (name_in_signature == f"L_{name_in_placeholder[2:]}") - ) - for node in traced_module.graph.nodes: - if not (node.op == "placeholder"): - continue - if node.name not in names: - continue - node.target = f"L_{node.target[2:]}" - node.name = f"L_{node.name[2:]}" - traced_module.recompile() + def get_diff_input_names(): + placeholder_names = set(get_input_names_from_placeholder()) + return [ + (i, name) + for i, name in enumerate(get_input_names_from_signature()) + if name not in placeholder_names + ] - handle_prefix_difference() + if len(inputs) > len(traced_sample_inputs): + diff_input_names = get_diff_input_names() + first_node = next(iter(traced_module.graph.nodes)) + for _, name in diff_input_names: + if name.startswith("l_"): + name = "L_" + name[2:] + with traced_module.graph.inserting_before(first_node): + new_node = traced_module.graph.placeholder(name) + new_node.name = name + new_node.target = name + traced_module.recompile() if len(get_zip_filter_names()) > 0 and set(get_input_names_from_signature()) == set( get_input_names_from_placeholder()