Skip to content

Commit 82d514e

Browse files
committed
Implement non_batch_call_method_view_pass.
1 parent 5068a8b commit 82d514e

File tree

7 files changed

+177
-7
lines changed

7 files changed

+177
-7
lines changed
Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
import operator
2+
from collections import defaultdict, deque
3+
import torch.fx as fx
4+
from graph_net.torch.dim_gen_passes import DimensionGeneralizationPass
5+
import os
6+
7+
8+
class ConcretePass(DimensionGeneralizationPass):
    """Fx pass that replaces hardcoded non-batch constants in ``view`` calls
    with dynamic ``size()`` calls, so the rewritten graph tolerates inputs
    whose non-batch dimensions vary at runtime.

    Requires ShapeProp to have populated ``tensor_meta`` on the input nodes.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @classmethod
    def get_pass_name(cls) -> str:
        # Pass name is derived from this file's basename without the ".py"
        # suffix (fixed: was annotated `-> bool` and lacked @classmethod).
        return os.path.basename(__file__)[:-3]

    def need_rewrite(self, traced_module: fx.GraphModule) -> bool:
        """Return True if any node in the graph is a rewrite candidate."""
        # This pass only generalizes non-batch axes; when axis 0 is among the
        # requested axes, a different pass is responsible.
        if 0 in self.axes:
            return False
        return any(self._node_need_rewrite(node) for node in traced_module.graph.nodes)

    def _node_need_rewrite(self, node) -> bool:
        """Return True for ``x.view(int, ...)`` calls with no ``-1`` wildcard."""
        if node.op != "call_method":
            return False
        if node.target != "view":
            return False
        if len(node.args) < 2:
            return False
        if not isinstance(node.args[1], int):
            return False
        # A -1 wildcard already lets PyTorch infer that dim; nothing to do.
        if -1 in node.args[1:]:
            return False
        return True

    def rewrite(self, traced_module: fx.GraphModule) -> fx.GraphModule:
        """
        Fx Pass: Replaces hardcoded constants in 'view' ops that match an
        input tensor dimension with a dynamic 'size()' call.

        Raises:
            RuntimeError: if a rewrite candidate's input node lacks
                ``tensor_meta`` (ShapeProp was not run).
        """
        # Build a fresh graph and map every old node to its copy in it.
        new_graph = fx.Graph()
        val_map = {}

        def make_size_node(input_tensor_node, input_index):
            # Emit `input.size(input_index)` in the new graph.
            return new_graph.call_method(
                "size", args=(val_map[input_tensor_node], input_index)
            )

        def get_index_map_of_common_dim(input_shape, view_args):
            # Greedily match each view arg to the first still-unused input
            # dimension of equal size; duplicate sizes are consumed
            # left-to-right via a deque per dimension value.
            dim2input_indices = defaultdict(deque)
            for input_index, dim in enumerate(input_shape):
                dim2input_indices[dim].append(input_index)

            # arg_index -> input_index for dims common to both shapes.
            common_arg_index2input_index = {}
            for arg_index, arg in enumerate(view_args):
                if arg in dim2input_indices and dim2input_indices[arg]:
                    common_arg_index2input_index[arg_index] = dim2input_indices[arg].popleft()
            return common_arg_index2input_index

        def get_new_tuple_args(input_tensor_node, input_shape, view_args):
            common = get_index_map_of_common_dim(input_shape, view_args)
            rest_view_indices = sorted(set(range(len(view_args))) - set(common.keys()))
            rest_input_indices = sorted(set(range(len(input_shape))) - set(common.values()))

            new_view_args_dict = {}
            for arg_index, input_index in common.items():
                if arg_index == 0:
                    # Keep the leading (batch) view arg hardcoded; this pass
                    # only generalizes non-batch axes.
                    new_view_args_dict[arg_index] = view_args[arg_index]
                else:
                    new_view_args_dict[arg_index] = make_size_node(
                        input_tensor_node, input_index
                    )

            if not rest_view_indices and not rest_input_indices:
                # Every view arg matched an input dim: nothing left to do.
                # (Fixed: previously fell through to the warning branch.)
                pass
            elif len(rest_view_indices) == 1 and len(rest_input_indices) > 1:
                # Merge the rest input dims into one product node.
                # e.g. input_shape=[1, 226, 4, 8], view_args=[1, 226, 32]
                size_nodes = [
                    make_size_node(input_tensor_node, input_index)
                    for input_index in rest_input_indices
                ]
                mul_node = size_nodes[0]
                for size_node in size_nodes[1:]:
                    mul_node = new_graph.call_function(
                        operator.mul, args=(mul_node, size_node)
                    )
                new_view_args_dict[rest_view_indices[0]] = mul_node
            elif (
                len(rest_input_indices) == 1
                and len(rest_view_indices) == 2
                and view_args[rest_view_indices[0]] == view_args[rest_view_indices[1]]
            ):
                # Factorize the single rest input dim into two equal dims
                # with sqrt.
                # e.g. input_shape=[1, 9216, 128], view_args=[1, 96, 96, 128]
                size_node = make_size_node(input_tensor_node, rest_input_indices[0])
                pow_node = new_graph.call_function(operator.pow, args=(size_node, 0.5))
                int_node = new_graph.call_function(int, args=(pow_node,))
                for arg_index in rest_view_indices:
                    new_view_args_dict[arg_index] = int_node
            else:
                # Unsupported pattern: keep the original constants.
                # (Fixed: size nodes are now created lazily per branch, so
                # this path no longer leaves dead `size()` nodes in the graph.)
                print(f"Not Support rewriting for {input_shape=}, {view_args=}")
                for arg_index in rest_view_indices:
                    new_view_args_dict[arg_index] = view_args[arg_index]

            # Restore positional order of the view arguments.
            return tuple(dict(sorted(new_view_args_dict.items())).values())

        for node in traced_module.graph.nodes:
            if self._node_need_rewrite(node):
                input_tensor_node = node.args[0]

                # --- Dependency on ShapeProp Results ---
                # input_shape is the static shape (e.g., batch_size, C, H, W)
                input_meta = input_tensor_node.meta.get("tensor_meta")
                if input_meta is None:
                    raise RuntimeError(
                        f"Node {input_tensor_node.name} lacks tensor_meta. Did ShapeProp run?"
                    )

                input_shape = input_meta.shape
                view_args = node.args[1:]
                new_view_args = get_new_tuple_args(
                    input_tensor_node, input_shape, view_args
                )

                # Rebuild the view node in the new graph against the mapped
                # input node (nodes are visited in topological order, so the
                # input is guaranteed to be in val_map already).
                new_input_node = val_map[input_tensor_node]
                new_node = new_graph.call_method(
                    "view", args=(new_input_node, *new_view_args)
                )
                val_map[node] = new_node
            else:
                # Copy all other nodes verbatim into the new graph.
                val_map[node] = new_graph.node_copy(node, lambda x: val_map[x])

        # Replace the old graph with the new graph and return.
        traced_module.graph = new_graph
        traced_module.recompile()
        return traced_module

graph_net/torch/static_to_dynamic.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def make_config(self, pass_names=()):
4646
"naive_call_method_reshape_pass",
4747
"naive_call_method_expand_pass",
4848
"non_batch_call_method_expand_pass",
49+
"non_batch_call_method_view_pass",
4950
"non_batch_call_function_arange_pass", # typos: skip
5051
"non_batch_call_function_getitem_slice_pass",
5152
"non_batch_call_function_full_pass",

samples/transformers-auto-model/hf-tiny-model-private_tiny-random-Data2VecVisionModel/graph_net.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
"naive_call_method_reshape_pass",
2222
"naive_call_method_expand_pass",
2323
"non_batch_call_method_expand_pass",
24+
"non_batch_call_method_view_pass",
2425
"non_batch_call_function_arange_pass",
2526
"non_batch_call_function_getitem_slice_pass",
2627
"non_batch_call_function_full_pass",
@@ -29,4 +30,4 @@
2930
"non_batch_call_function_arange_plus_one_pass"
3031
],
3132
"symbolic_dimension_reifier": "naive_cv_sym_dim_reifier"
32-
}
33+
}

samples/transformers-auto-model/hf-tiny-model-private_tiny-random-Swinv2ForImageClassification/graph_net.json

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,8 @@
1515
"region:us"
1616
],
1717
"heuristic_tag": "computer_vision",
18+
"dimension_generalization_passes": [
19+
"non_batch_call_method_view_pass"
20+
],
1821
"symbolic_dimension_reifier": "naive_cv_sym_dim_reifier"
19-
}
22+
}

samples/transformers-auto-model/microsoft_swin-base-patch4-window12-384-in22k/graph_net.json

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,5 +20,8 @@
2020
"region:us"
2121
],
2222
"heuristic_tag": "computer_vision",
23+
"dimension_generalization_passes": [
24+
"non_batch_call_method_view_pass"
25+
],
2326
"symbolic_dimension_reifier": "naive_cv_sym_dim_reifier"
24-
}
27+
}

samples/ultralytics/yolov3-tinyu/graph_net.json

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
"dynamic": false,
66
"source": "ultralytics",
77
"heuristic_tag": "computer_vision",
8-
"dimension_generalization_passes": [],
8+
"dimension_generalization_passes": [
9+
"non_batch_call_method_view_pass"
10+
],
911
"symbolic_dimension_reifier": "naive_cv_sym_dim_reifier"
10-
}
12+
}

samples/ultralytics/yolov9c/graph_net.json

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
"dynamic": false,
66
"source": "ultralytics",
77
"heuristic_tag": "computer_vision",
8-
"dimension_generalization_passes": [],
8+
"dimension_generalization_passes": [
9+
"non_batch_call_method_view_pass"
10+
],
911
"symbolic_dimension_reifier": "naive_cv_sym_dim_reifier"
10-
}
12+
}

0 commit comments

Comments
 (0)