@@ -74,6 +74,7 @@ def call_method_predicator(method_name):
7474 return [
7575 (call_method_predicator ("view" ), dynamic_view_rewrite_pass ),
7676 (call_method_predicator ("reshape" ), dynamic_reshape_rewrite_pass ),
77+ (call_method_predicator ("expand" ), dynamic_expand_rewrite_pass ),
7778 ]
7879
7980 def forward (self , * args , ** kwargs ):
@@ -87,6 +88,13 @@ def forward(self, *args, **kwargs):
8788 # return traced_module(*args, **kwargs)
8889
8990
def has_call_method(traced_module: fx.GraphModule, method_name) -> bool:
    """Return True iff the traced graph contains a 'call_method' node targeting *method_name*."""
    return any(
        node.op == "call_method" and node.target == method_name
        for node in traced_module.graph.nodes
    )
97+
9098def dynamic_view_rewrite_pass (traced_module : fx .GraphModule ) -> fx .GraphModule :
9199 """
92100 Fx Pass: Replaces hardcoded constants in 'view' ops that match an input tensor dimension
@@ -141,12 +149,15 @@ def dynamic_view_rewrite_pass(traced_module: fx.GraphModule) -> fx.GraphModule:
141149 # and this dimension is one we wish to generalize (e.g., batch size, axis 0)
142150 if target_dim == input_dim_size :
143151 # Prioritize matching the batch size (axis 0)
144- if axis_idx == 0 :
145- matched_axis = 0
152+ if i == 0 and axis_idx == 0 :
153+ matched_axis = axis_idx
146154 break
147- # If not axis 0, but still a match, consider generalization
148- elif matched_axis == - 1 :
155+ elif i > 0 and axis_idx > 0 and input_dim_size > 1 :
149156 matched_axis = axis_idx
157+ break
158+ else :
159+ # Do nothing.
160+ pass
150161
151162 if matched_axis != - 1 :
152163 # Found a matching dynamic axis (matched_axis), replace it with a size() call
@@ -242,12 +253,15 @@ def dynamic_reshape_rewrite_pass(traced_module: fx.GraphModule) -> fx.GraphModul
242253 # and this dimension is one we wish to generalize (e.g., batch size, axis 0)
243254 if target_dim == input_dim_size :
244255 # Prioritize matching the batch size (axis 0)
245- if axis_idx == 0 :
246- matched_axis = 0
256+ if i == 0 and axis_idx == 0 :
257+ matched_axis = axis_idx
247258 break
248- # If not axis 0, but still a match, consider generalization
249- elif matched_axis == - 1 :
259+ elif i > 0 and axis_idx > 0 and input_dim_size > 1 :
250260 matched_axis = axis_idx
261+ break
262+ else :
263+ # Do nothing.
264+ pass
251265
252266 if matched_axis != - 1 :
253267 # Found a matching dynamic axis (matched_axis), replace it with a size() call
@@ -289,8 +303,100 @@ def dynamic_reshape_rewrite_pass(traced_module: fx.GraphModule) -> fx.GraphModul
289303 return traced_module
290304
291305
292- def has_call_method (traced_module : fx .GraphModule , method_name ) -> bool :
def dynamic_expand_rewrite_pass(traced_module: fx.GraphModule) -> fx.GraphModule:
    """
    Fx Pass: Replaces hardcoded constants in 'expand' ops that match an input tensor
    dimension with a dynamic 'size()' call, so the module generalizes across input
    sizes. The primary goal is to dynamicize the batch size (axis 0).

    Precondition: ShapeProp must have populated 'tensor_meta' on every node feeding
    an 'expand'; a missing meta raises RuntimeError.

    Returns the same GraphModule with its graph rewritten and recompiled.
    """
    # Rebuild the whole graph; val_map links old-graph nodes to their new-graph copies.
    new_graph = fx.Graph()
    val_map = {}

    for node in traced_module.graph.nodes:
        if node.op == "call_method" and node.target == "expand":
            input_tensor_node = node.args[0]
            # Target shape arguments of expand, e.g. (1, 4, 6, 64).
            expand_args = node.args[1:]

            # --- Dependency on ShapeProp results ---
            input_meta = input_tensor_node.meta.get("tensor_meta")
            if input_meta is None:
                raise RuntimeError(
                    f"Node {input_tensor_node.name} lacks tensor_meta. Did ShapeProp run?"
                )
            input_shape = input_meta.shape

            # torch.Tensor.expand aligns its arguments with the *trailing* input
            # dimensions and may prepend new leading dims. 'rank_offset' maps an
            # expand-arg index to its input axis (negative => brand-new leading dim).
            rank_offset = len(expand_args) - len(input_shape)

            new_expand_args = []
            for i, target_dim in enumerate(expand_args):
                # 1. Dynamic dimensions (-1) or non-int args (e.g. fx.Node) pass
                #    through, remapped into the new graph when they are old nodes.
                if not isinstance(target_dim, int) or target_dim < 1:
                    new_expand_args.append(val_map.get(target_dim, target_dim))
                    continue

                input_axis = i - rank_offset
                if input_axis < 0:
                    # This arg creates a new leading dimension: there is no input
                    # axis to read a dynamic size from, so keep the constant.
                    new_expand_args.append(target_dim)
                    continue

                # 2. Hardcoded constant: replace with input.size(axis) when it
                #    mirrors an input dimension we wish to generalize. Size-1 axes
                #    (other than the batch axis 0) are kept as constants — expand
                #    broadcasts them, so they are already shape-general.
                input_dim_size = input_shape[input_axis]
                if target_dim == input_dim_size and (
                    input_axis == 0 or input_dim_size > 1
                ):
                    # NOTE: input_tensor_node must already be mapped via val_map.
                    new_input_node = val_map[input_tensor_node]
                    size_node = new_graph.call_method(
                        "size", args=(new_input_node, input_axis)
                    )
                    new_expand_args.append(size_node)
                else:
                    new_expand_args.append(target_dim)

            # --- Rebuild the expand node in the new graph ---
            new_input_node = val_map[input_tensor_node]
            new_node = new_graph.call_method(
                "expand", args=(new_input_node, *new_expand_args)
            )
            val_map[node] = new_node
        else:
            # Copy every other node verbatim, remapping its inputs into the new graph.
            new_node = new_graph.node_copy(node, lambda x: val_map[x])
            val_map[node] = new_node

    # Swap in the rewritten graph and regenerate forward().
    traced_module.graph = new_graph
    traced_module.recompile()
    return traced_module
0 commit comments