Skip to content

Commit 196bc71

Browse files
committed
Implement non_batch_call_method_view_pass.
1 parent 5068a8b commit 196bc71

File tree

7 files changed

+180
-7
lines changed

7 files changed

+180
-7
lines changed
Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
import operator
2+
from collections import defaultdict, deque
3+
import torch.fx as fx
4+
from graph_net.torch.dim_gen_passes import DimensionGeneralizationPass
5+
import os
6+
7+
8+
class ConcretePass(DimensionGeneralizationPass):
9+
def __init__(self, *args, **kwargs):
10+
super().__init__(*args, **kwargs)
11+
12+
def get_pass_name(cls) -> bool:
13+
return os.path.basename(__file__)[:-3]
14+
15+
def need_rewrite(self, traced_module: fx.GraphModule) -> bool:
16+
print(f"[{__file__}] {self.axes=}")
17+
if 0 in self.axes:
18+
return False
19+
return any(self._node_need_rewrite(node) for node in traced_module.graph.nodes)
20+
21+
def _node_need_rewrite(self, node) -> bool:
22+
if not (node.op == "call_method"):
23+
return False
24+
if not (node.target == "view"):
25+
return False
26+
if not (len(node.args) >= 2):
27+
return False
28+
if not (isinstance(node.args[1], int)):
29+
return False
30+
if -1 in node.args[1:]:
31+
return False
32+
return True
33+
34+
def rewrite(self, traced_module: fx.GraphModule) -> fx.GraphModule:
35+
"""
36+
Fx Pass: Replaces hardcoded constants in 'view' ops that match an input tensor dimension
37+
with a dynamic 'size()' call. The primary goal is to dynamicize the batch size (axis 0).
38+
"""
39+
# Create a new graph to hold the rewritten nodes
40+
new_graph = fx.Graph()
41+
42+
# Create a map to link nodes from the old graph to nodes in the new graph
43+
val_map = {}
44+
45+
def get_index_map_of_common_dim(input_shape, view_args):
46+
dim2input_indices = defaultdict(deque)
47+
for input_index, dim in enumerate(input_shape):
48+
dim2input_indices[dim].append(input_index)
49+
50+
# arg_index: input_index
51+
common_arg_index2input_index = {}
52+
for arg_index, arg in enumerate(view_args):
53+
if arg in dim2input_indices.keys() and dim2input_indices[arg]:
54+
input_index = dim2input_indices[arg].popleft()
55+
common_arg_index2input_index[arg_index] = input_index
56+
return common_arg_index2input_index
57+
58+
def get_new_tuple_args(input_shape, view_args):
59+
common_arg_index2input_index = get_index_map_of_common_dim(
60+
input_shape, view_args
61+
)
62+
rest_view_indices = list(
63+
set(range(len(view_args))) - set(common_arg_index2input_index.keys())
64+
)
65+
rest_input_indices = list(
66+
set(range(len(input_shape)))
67+
- set(common_arg_index2input_index.values())
68+
)
69+
print(
70+
f"{input_shape=}, {view_args=}, {common_arg_index2input_index=}, {rest_view_indices=}, {rest_input_indices=}"
71+
)
72+
73+
new_view_args_dict = {}
74+
for arg_index, input_index in common_arg_index2input_index.items():
75+
if arg_index == 0:
76+
new_view_args_dict[arg_index] = view_args[arg_index]
77+
else:
78+
new_input_node = val_map[input_tensor_node]
79+
size_node = new_graph.call_method(
80+
"size", args=(new_input_node, input_index)
81+
)
82+
new_view_args_dict[arg_index] = size_node
83+
84+
size_nodes = []
85+
for input_index in sorted(rest_input_indices):
86+
new_input_node = val_map[input_tensor_node]
87+
size_nodes.append(
88+
new_graph.call_method("size", args=(new_input_node, input_index))
89+
)
90+
91+
if len(rest_view_indices) == 1 and len(rest_input_indices) > 1:
92+
# Merge the reset input dims into 1.
93+
# e.g. input_shape=[1, 226, 4, 8], view_args=[1, 226, 32]
94+
mul_node = new_graph.call_function(
95+
operator.mul, args=(size_nodes[0], size_nodes[1])
96+
)
97+
for i in range(2, len(size_nodes)):
98+
mul_node = new_graph.call_function(
99+
operator.mul, args=(mul_node, size_nodes[i])
100+
)
101+
new_view_args_dict[rest_view_indices[0]] = mul_node
102+
elif len(rest_view_indices) == 2 and len(rest_input_indices) == 1:
103+
# Factorize the input dim with sqrt.
104+
# e.g. input_shape=[1, 9216, 128], view_args=[1, 96, 96, 128]
105+
assert (
106+
view_args[rest_view_indices[0]] == view_args[rest_view_indices[1]]
107+
)
108+
pow_node = new_graph.call_function(
109+
operator.pow, args=(size_nodes[0], 0.5)
110+
)
111+
int_node = new_graph.call_function(int, args=(pow_node))
112+
for arg_index in rest_view_indices:
113+
new_view_args_dict[arg_index] = int_node
114+
else:
115+
print(f"Not Support rewriting for {input_shape=}, {view_args=}")
116+
for arg_index in rest_view_indices:
117+
new_view_args_dict[arg_index] = view_args[arg_index]
118+
119+
new_view_args = dict(sorted(new_view_args_dict.items())).values()
120+
return tuple(new_view_args)
121+
122+
for node in traced_module.graph.nodes:
123+
if self._node_need_rewrite(node):
124+
# Get the input tensor node
125+
input_tensor_node = node.args[0]
126+
127+
# --- Dependency on ShapeProp Results ---
128+
# input_shape is the static shape (e.g., batch_size, C, H, W)
129+
input_meta = input_tensor_node.meta.get("tensor_meta")
130+
if input_meta is None:
131+
raise RuntimeError(
132+
f"Node {input_tensor_node.name} lacks tensor_meta. Did ShapeProp run?"
133+
)
134+
135+
# Get the target shape arguments for view (e.g., 1, -1, 6, 64)
136+
input_shape = input_tensor_node.meta["tensor_meta"].shape
137+
view_args = node.args[1:]
138+
new_view_args = get_new_tuple_args(input_shape, view_args)
139+
140+
# --- Rebuild the view node ---
141+
# 1. Map the input tensor node to the new graph node
142+
new_input_node = val_map[input_tensor_node]
143+
144+
# 2. Insert the new view node into the new graph
145+
# with new_graph.inserting_after(new_input_node):
146+
new_node = new_graph.call_method(
147+
"view", args=(new_input_node, *new_view_args)
148+
)
149+
150+
# 3. Map the old node to the new node
151+
val_map[node] = new_node
152+
153+
else:
154+
# Copy other nodes to the new graph
155+
new_node = new_graph.node_copy(node, lambda x: val_map[x])
156+
val_map[node] = new_node
157+
158+
# Replace the old graph with the new graph and return
159+
traced_module.graph = new_graph
160+
traced_module.recompile()
161+
return traced_module

graph_net/torch/static_to_dynamic.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def make_config(self, pass_names=()):
4646
"naive_call_method_reshape_pass",
4747
"naive_call_method_expand_pass",
4848
"non_batch_call_method_expand_pass",
49+
"non_batch_call_method_view_pass",
4950
"non_batch_call_function_arange_pass", # typos: skip
5051
"non_batch_call_function_getitem_slice_pass",
5152
"non_batch_call_function_full_pass",

samples/transformers-auto-model/hf-tiny-model-private_tiny-random-Data2VecVisionModel/graph_net.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
"naive_call_method_reshape_pass",
2222
"naive_call_method_expand_pass",
2323
"non_batch_call_method_expand_pass",
24+
"non_batch_call_method_view_pass",
2425
"non_batch_call_function_arange_pass",
2526
"non_batch_call_function_getitem_slice_pass",
2627
"non_batch_call_function_full_pass",
@@ -29,4 +30,4 @@
2930
"non_batch_call_function_arange_plus_one_pass"
3031
],
3132
"symbolic_dimension_reifier": "naive_cv_sym_dim_reifier"
32-
}
33+
}

samples/transformers-auto-model/hf-tiny-model-private_tiny-random-Swinv2ForImageClassification/graph_net.json

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,8 @@
1515
"region:us"
1616
],
1717
"heuristic_tag": "computer_vision",
18+
"dimension_generalization_passes": [
19+
"non_batch_call_method_view_pass"
20+
],
1821
"symbolic_dimension_reifier": "naive_cv_sym_dim_reifier"
19-
}
22+
}

samples/transformers-auto-model/microsoft_swin-base-patch4-window12-384-in22k/graph_net.json

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,5 +20,8 @@
2020
"region:us"
2121
],
2222
"heuristic_tag": "computer_vision",
23+
"dimension_generalization_passes": [
24+
"non_batch_call_method_view_pass"
25+
],
2326
"symbolic_dimension_reifier": "naive_cv_sym_dim_reifier"
24-
}
27+
}

samples/ultralytics/yolov3-tinyu/graph_net.json

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
"dynamic": false,
66
"source": "ultralytics",
77
"heuristic_tag": "computer_vision",
8-
"dimension_generalization_passes": [],
8+
"dimension_generalization_passes": [
9+
"non_batch_call_method_view_pass"
10+
],
911
"symbolic_dimension_reifier": "naive_cv_sym_dim_reifier"
10-
}
12+
}

samples/ultralytics/yolov9c/graph_net.json

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
"dynamic": false,
66
"source": "ultralytics",
77
"heuristic_tag": "computer_vision",
8-
"dimension_generalization_passes": [],
8+
"dimension_generalization_passes": [
9+
"non_batch_call_method_view_pass"
10+
],
911
"symbolic_dimension_reifier": "naive_cv_sym_dim_reifier"
10-
}
12+
}

0 commit comments

Comments
 (0)