Skip to content

Commit 4582a6f

Browse files
authored
support dimension generalization for torch.Tensor.expand() (#383)
* support checking model redundancy * revert change of vision_model_test * reformat python code. * reformat bert_model_test.py and utils.py * minor fix * fix failed check by comparing directories after os.path.realpath() * fix bugs in check_validate.sh * set dynamic=False in single_device_runner.py * reset graph hash * add robustness code for generating input tensor constraints * Introduce input_tensor_constraints.py using shape propagation logic. * support dimension generalization for torch.Tensor.view and torch.Tensor.reshape * 1) support dimension generalization for torch.Tensor.expand(); 2) fix bugs in generalization for torch.Tensor.view and torch.Tensor.reshape
1 parent 54c2597 commit 4582a6f

File tree

2 files changed

+120
-14
lines changed

2 files changed

+120
-14
lines changed
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
samples/timm/resnetaa50d.d_in12k
1+
# samples/timm/resnetaa50d.d_in12k
22
# samples/transformers-auto-model/opus-mt-en-gmw
3-
# samples/transformers-auto-model/Michielo_mt5-small_nl-en_translation
3+
samples/transformers-auto-model/Michielo_mt5-small_nl-en_translation

graph_net/torch/static_to_dynamic.py

Lines changed: 118 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ def call_method_predicator(method_name):
7474
return [
7575
(call_method_predicator("view"), dynamic_view_rewrite_pass),
7676
(call_method_predicator("reshape"), dynamic_reshape_rewrite_pass),
77+
(call_method_predicator("expand"), dynamic_expand_rewrite_pass),
7778
]
7879

7980
def forward(self, *args, **kwargs):
@@ -87,6 +88,13 @@ def forward(self, *args, **kwargs):
8788
# return traced_module(*args, **kwargs)
8889

8990

91+
def has_call_method(traced_module: fx.GraphModule, method_name) -> bool:
    """Return True if the traced graph contains a call_method node targeting *method_name*."""
    return any(
        node.op == "call_method" and node.target == method_name
        for node in traced_module.graph.nodes
    )
96+
97+
9098
def dynamic_view_rewrite_pass(traced_module: fx.GraphModule) -> fx.GraphModule:
9199
"""
92100
Fx Pass: Replaces hardcoded constants in 'view' ops that match an input tensor dimension
@@ -141,12 +149,15 @@ def dynamic_view_rewrite_pass(traced_module: fx.GraphModule) -> fx.GraphModule:
141149
# and this dimension is one we wish to generalize (e.g., batch size, axis 0)
142150
if target_dim == input_dim_size:
143151
# Prioritize matching the batch size (axis 0)
144-
if axis_idx == 0:
145-
matched_axis = 0
152+
if i == 0 and axis_idx == 0:
153+
matched_axis = axis_idx
146154
break
147-
# If not axis 0, but still a match, consider generalization
148-
elif matched_axis == -1:
155+
elif i > 0 and axis_idx > 0 and input_dim_size > 1:
149156
matched_axis = axis_idx
157+
break
158+
else:
159+
# Do nothing.
160+
pass
150161

151162
if matched_axis != -1:
152163
# Found a matching dynamic axis (matched_axis), replace it with a size() call
@@ -242,12 +253,15 @@ def dynamic_reshape_rewrite_pass(traced_module: fx.GraphModule) -> fx.GraphModul
242253
# and this dimension is one we wish to generalize (e.g., batch size, axis 0)
243254
if target_dim == input_dim_size:
244255
# Prioritize matching the batch size (axis 0)
245-
if axis_idx == 0:
246-
matched_axis = 0
256+
if i == 0 and axis_idx == 0:
257+
matched_axis = axis_idx
247258
break
248-
# If not axis 0, but still a match, consider generalization
249-
elif matched_axis == -1:
259+
elif i > 0 and axis_idx > 0 and input_dim_size > 1:
250260
matched_axis = axis_idx
261+
break
262+
else:
263+
# Do nothing.
264+
pass
251265

252266
if matched_axis != -1:
253267
# Found a matching dynamic axis (matched_axis), replace it with a size() call
@@ -289,8 +303,100 @@ def dynamic_reshape_rewrite_pass(traced_module: fx.GraphModule) -> fx.GraphModul
289303
return traced_module
290304

291305

292-
def has_call_method(traced_module: fx.GraphModule, method_name) -> bool:
306+
def dynamic_expand_rewrite_pass(traced_module: fx.GraphModule) -> fx.GraphModule:
    """
    Fx Pass: Replaces hardcoded constants in 'expand' ops that match an input tensor
    dimension with a dynamic 'size()' call. The primary goal is to dynamicize the
    batch size (axis 0).

    Requires ShapeProp to have been run first, so that every expand input node
    carries 'tensor_meta' with its propagated static shape.

    Args:
        traced_module: the FX-traced module to rewrite (modified in place).

    Returns:
        The same module, with its graph rebuilt and recompiled.

    Raises:
        RuntimeError: if an expand input node lacks 'tensor_meta'.
    """
    # Create a new graph to hold the rewritten nodes
    new_graph = fx.Graph()

    # Map nodes of the old graph to their copies in the new graph
    val_map = {}

    for node in traced_module.graph.nodes:
        if node.op == "call_method" and node.target == "expand":
            # The tensor being expanded
            input_tensor_node = node.args[0]
            # Target shape arguments of expand (e.g. 1, 4, 6, 64)
            expand_args = node.args[1:]

            # --- Dependency on ShapeProp results ---
            # input_shape is the static shape (e.g., batch_size, C, H, W)
            input_meta = input_tensor_node.meta.get("tensor_meta")
            if input_meta is None:
                raise RuntimeError(
                    f"Node {input_tensor_node.name} lacks tensor_meta. Did ShapeProp run?"
                )

            input_shape = input_meta.shape

            # expand() may PREPEND new leading dimensions (target rank can exceed
            # input rank), so target axes are right-aligned with input axes:
            # output axis i corresponds to input axis i - offset, which is
            # negative for newly added leading axes.
            offset = len(expand_args) - len(input_shape)

            # Build the new list of expand arguments
            new_expand_args = []

            # Iterate over the target dimensions of expand (dim0, dim1, ...)
            for i, target_dim in enumerate(expand_args):
                # 1. Keep dynamic dims (-1) and non-int args (e.g. fx.Node) as-is,
                #    remapping old-graph nodes into the new graph where needed.
                if not isinstance(target_dim, int) or target_dim < 1:
                    new_expand_args.append(
                        val_map[target_dim] if target_dim in val_map else target_dim
                    )
                    continue

                # Newly prepended leading dimension: there is no input axis to
                # read a dynamic size from, so the constant must stay.
                input_axis = i - offset
                if input_axis < 0:
                    new_expand_args.append(target_dim)
                    continue

                # 2. Handle hardcoded constants (e.g., 1, 6, 64)

                # --- Core Logic: Find the matching dynamic axis ---

                # Default: keep the hardcoded constant if no dynamic axis matches
                best_match = target_dim
                matched_axis = -1

                input_dim_size = input_shape[input_axis]
                if target_dim == input_dim_size:
                    if input_axis == 0:
                        # Prioritize generalizing the batch dimension (axis 0)
                        matched_axis = input_axis
                    elif input_dim_size > 1:
                        # Non-batch axis whose size matches and is not a
                        # broadcastable singleton
                        matched_axis = input_axis

                if matched_axis != -1:
                    # Found a matching dynamic axis: replace the constant with a
                    # size() call on the (already remapped) input node.
                    new_input_node = val_map[input_tensor_node]

                    size_node = new_graph.call_method(
                        "size", args=(new_input_node, matched_axis)
                    )

                    best_match = size_node

                new_expand_args.append(best_match)

            # --- Rebuild the expand node ---
            # 1. Map the input tensor node to the new graph node
            new_input_node = val_map[input_tensor_node]

            # 2. Insert the new expand node into the new graph
            new_node = new_graph.call_method(
                "expand", args=(new_input_node, *new_expand_args)
            )

            # 3. Map the old node to the new node
            val_map[node] = new_node

        else:
            # Copy all other nodes to the new graph unchanged
            new_node = new_graph.node_copy(node, lambda x: val_map[x])
            val_map[node] = new_node

    # Replace the old graph with the new graph and return
    traced_module.graph = new_graph
    traced_module.recompile()
    return traced_module

0 commit comments

Comments
 (0)