Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 31 additions & 66 deletions graph_net/torch/fx_graph_parse_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,24 +166,6 @@ def handle_placeholder_name(pattern2replacement):

handle_placeholder_name(pattern2replacement)

def get_diff_input_names():
placeholder_names = set(get_input_names_from_placeholder())
return [
(i, name)
for i, name in enumerate(get_input_names_from_signature())
if name not in placeholder_names
]

diff_input_names = get_diff_input_names()
if len(diff_input_names) > 0:
first_node = next(iter(traced_module.graph.nodes))

with traced_module.graph.inserting_before(first_node):
for _, name in diff_input_names:
traced_module.graph.placeholder(name)

traced_module.recompile()

def get_zip_filter_names():
names_from_signature = get_input_names_from_signature()
names_from_placeholder = get_input_names_from_placeholder()
Expand All @@ -198,61 +180,44 @@ def get_zip_filter_names():
)

def handle_underscore_suffix_difference():
zip_filter_names = get_zip_filter_names()
if not (len(zip_filter_names) > 0):
return
names = set(
name_in_placeholder
for _0, name_in_signature, name_in_placeholder in zip_filter_names
if (f"{name_in_signature}_" == name_in_placeholder)
)
names_in_signature = set(
name_in_signature
for _0, name_in_signature, name_in_placeholder in zip_filter_names
)
underscore_suffixed_names_in_signature = set(
f"{name_in_signature}_"
for _0, name_in_signature, name_in_placeholder in zip_filter_names
)
if len(names_in_signature & underscore_suffixed_names_in_signature) == 0:
disordered_names = set(
name_in_placeholder
for _0, name_in_signature, name_in_placeholder in zip_filter_names
if f"{name_in_signature}_" != name_in_placeholder
if name_in_placeholder in underscore_suffixed_names_in_signature
)
names = names | disordered_names
for node in traced_module.graph.nodes:
if not (node.op == "placeholder"):
ph_nodes = {
node.name: node
for node in traced_module.graph.nodes
if node.op == "placeholder"
}
sig_names = get_input_names_from_signature()
sig_names_set = set(sig_names)
for name in sig_names:
target_ph_name = f"{name}_"
if name in ph_nodes or target_ph_name not in ph_nodes:
continue
if node.name not in names:
if target_ph_name in sig_names_set:
continue
node.target = node.target[:-1]
node.name = node.name[:-1]
node = ph_nodes[target_ph_name]
node.target = node.name = name
traced_module.recompile()

handle_underscore_suffix_difference()

def handle_prefix_difference():
zip_filter_names = get_zip_filter_names()
if not (len(zip_filter_names) > 0):
return
names = set(
name_in_placeholder
for _0, name_in_signature, name_in_placeholder in zip_filter_names
if (f"l_{name_in_signature[2:]}" == name_in_placeholder)
if (name_in_signature == f"L_{name_in_placeholder[2:]}")
)
for node in traced_module.graph.nodes:
if not (node.op == "placeholder"):
continue
if node.name not in names:
continue
node.target = f"L_{node.target[2:]}"
node.name = f"L_{node.name[2:]}"
traced_module.recompile()
def get_diff_input_names():
placeholder_names = set(get_input_names_from_placeholder())
return [
(i, name)
for i, name in enumerate(get_input_names_from_signature())
if name not in placeholder_names
]

handle_prefix_difference()
if len(inputs) > len(traced_sample_inputs):
diff_input_names = get_diff_input_names()
first_node = next(iter(traced_module.graph.nodes))
for _, name in diff_input_names:
if name.startswith("l_"):
name = "L_" + name[2:]
with traced_module.graph.inserting_before(first_node):
new_node = traced_module.graph.placeholder(name)
new_node.name = name
new_node.target = name
traced_module.recompile()

if len(get_zip_filter_names()) > 0 and set(get_input_names_from_signature()) == set(
get_input_names_from_placeholder()
Expand Down