Skip to content

Commit cdf9f09

Browse files
authored
[Bug Fix] Fix SymInt issue in subgraph decomposition (#506)
* fix * Add blank line for improved readability
1 parent e20e93f commit cdf9f09

File tree

1 file changed

+8
-12
lines changed

1 file changed

+8
-12
lines changed

graph_net/torch/fx_graph_parse_util.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -174,18 +174,14 @@ def get_diff_input_names():
174174
if name not in placeholder_names
175175
]
176176

177-
if len(inputs) == len(traced_sample_inputs) + 1:
178-
diff_input_names = get_diff_input_names()
179-
assert len(diff_input_names) == 1, f"{diff_input_names=}"
180-
pos, name = diff_input_names[0]
181-
for i, node in enumerate(traced_module.graph.nodes):
182-
if i < pos:
183-
assert node.op == "placeholder"
184-
elif i == pos:
185-
with traced_module.graph.inserting_before(node):
186-
traced_module.graph.placeholder(name)
187-
else:
188-
break
177+
diff_input_names = get_diff_input_names()
178+
if len(diff_input_names) > 0:
179+
first_node = next(iter(traced_module.graph.nodes))
180+
181+
with traced_module.graph.inserting_before(first_node):
182+
for _, name in diff_input_names:
183+
traced_module.graph.placeholder(name)
184+
189185
traced_module.recompile()
190186

191187
def get_zip_filter_names():

0 commit comments

Comments
 (0)