
Commit 508a99f

DimensionGeneralizationPass (#384)
* support checking model redundancy
* revert change of vision_model_test
* reformat python code
* reformat bert_model_test.py and utils.py
* minor fix
* fix failed check by comparing directories after os.path.realpath()
* fix bugs in check_validate.sh
* set dynamic=False in single_device_runner.py
* reset graph hash
* add robustness code for generating input tensor constraints
* introduce input_tensor_constraints.py using shape propagation logic
* support dimension generalization for torch.Tensor.view and torch.Tensor.reshape
* support dimension generalization for torch.Tensor.expand(), and fix bugs in generalization for torch.Tensor.view and torch.Tensor.reshape
* dimension_generalization_passes
* refactor DimensionGeneralizationPass.__init__ to accept argument dim_axes_pairs, enabling targeted configuration for specific use cases
1 parent 932cd03 commit 508a99f

11 files changed: +526 −355 lines changed
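
The central change is the new dim_axes_pairs argument that constraint_util.py threads into the dimension-generalization pass. The snippet below is purely illustrative (the dimension values are made up; the pass constructor itself lives in pass_base.py, which is not shown in this diff): it only shows the shape of dim_axes_pairs as built by symbolize_data_input_dims further down in this commit.

# Illustrative values only. Each pair is (concrete dim value, input axes where
# that value occurs), matching how symbolize_data_input_dims builds the tuple.
dim_axes_pairs = (
    (128, [0]),  # e.g. a batch size observed at axis 0
    (77, [1]),   # e.g. a sequence length observed at axis 1
)

for dim, axes in dim_axes_pairs:
    print(f"generalize dim value {dim} at axes {axes}")

constraint_util.py hands this tuple to _try_dimension_generalization, which instantiates the configured generalizer class (StaticToDynamic in the test configs) and calls it as dim_generalizer(model, dim_axes_pairs).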
Lines changed: 2 additions & 2 deletions
@@ -1,3 +1,3 @@
 # samples/timm/resnetaa50d.d_in12k
-# samples/transformers-auto-model/opus-mt-en-gmw
-samples/transformers-auto-model/Michielo_mt5-small_nl-en_translation
+samples/transformers-auto-model/opus-mt-en-gmw
+# samples/transformers-auto-model/Michielo_mt5-small_nl-en_translation

graph_net/constraint_util.py

Lines changed: 90 additions & 24 deletions
@@ -1,4 +1,5 @@
 from graph_net.dynamic_dim_constraints import DynamicDimConstraints
+from contextlib import AbstractContextManager
 from graph_net.imp_util import load_module
 from graph_net.tensor_meta import TensorMeta
 from typing import Callable
@@ -21,6 +22,7 @@ def __init__(self, config=None):
         self.model_runnable_predicator = self._make_model_runnable_predicator(
             self.config
         )
+        self.num_successful_handled_models = 0

     def _make_data_input_predicator(self, config):
         module = load_module(config["data_input_predicator_filepath"])
@@ -45,6 +47,8 @@ def _make_config(
         dimension_generalizer_config=None,
         model_path_prefix="",
         resume=False,
+        last_model_log_file=None,
+        limits_successfully_handled_models=None,
     ):
         if data_input_predicator_config is None:
             data_input_predicator_config = {}
@@ -64,6 +68,8 @@ def _make_config(
             "dimension_generalizer_filepath": dimension_generalizer_filepath,
             "dimension_generalizer_class_name": dimension_generalizer_class_name,
             "dimension_generalizer_config": dimension_generalizer_config,
+            "last_model_log_file": last_model_log_file,
+            "limits_successfully_handled_models": limits_successfully_handled_models,
         }

     def __call__(self, model_path):
@@ -86,24 +92,42 @@ def __call__(self, model_path):
         def data_input_predicator(input_var_name):
             return self.data_input_predicator(model_path, input_var_name)

-        with self._try_dimension_generalization(
-            model_path, tensor_metas
-        ) as tmp_model_path:
+        def get_tmp_model_path_ctx_mgr(dim_axes_pairs):
+            return self._try_dimension_generalization(
+                dim_axes_pairs, model_path, tensor_metas
+            )

+        def get_predicator_is_dyn_dim_cstr_feasible(tmp_model_path):
             def is_dyn_dim_cstr_feasible(dyn_dim_cstr):
                 return self._is_dyn_dim_cstr_feasible(
                     tmp_model_path, tensor_metas, dyn_dim_cstr
                 )

-            dyn_dim_cstr = symbolize_data_input_dims(
-                dyn_dim_cstr,
-                is_data_input=data_input_predicator,
-                is_dyn_dim_cstr_feasible=is_dyn_dim_cstr_feasible,
-            )
-            self._save_dyn_dim_cstr(dyn_dim_cstr, model_path)
+            return is_dyn_dim_cstr_feasible
+
+        dyn_dim_cstr_feasibility_ctx_mgr = DynDimCstrFeasibilityContextManager(
+            get_tmp_model_path_ctx_mgr=get_tmp_model_path_ctx_mgr,
+            get_predicator_is_dyn_dim_cstr_feasible=get_predicator_is_dyn_dim_cstr_feasible,
+        )
+        dyn_dim_cstr = symbolize_data_input_dims(
+            dyn_dim_cstr,
+            is_data_input=data_input_predicator,
+            dyn_dim_cstr_feasibility_ctx_mgr=dyn_dim_cstr_feasibility_ctx_mgr,
+        )
+        self._save_dyn_dim_cstr(dyn_dim_cstr, model_path)
+        if len(dyn_dim_cstr.symbols) > 0:
+            self.num_successful_handled_models += 1
+        limits = self.config["limits_successfully_handled_models"]
+        if limits is not None:
+            if self.num_successful_handled_models > limits:
+                print(
+                    "`num_successful_handled_models` exceeds config `limits_successfully_handled_models`",
+                    file=sys.stderr,
+                )
+                sys.exit(0)

     @contextmanager
-    def _try_dimension_generalization(self, model_path, tensor_metas):
+    def _try_dimension_generalization(self, dim_axes_pairs, model_path, tensor_metas):
         if self.config["dimension_generalizer_filepath"] is None:
             yield model_path
             return
@@ -115,20 +139,23 @@ def _try_dimension_generalization(self, model_path, tensor_metas):
             load_module(self.config["dimension_generalizer_filepath"]),
             self.config["dimension_generalizer_class_name"],
         )
-        pass_obj = decorator_cls(self.config["dimension_generalizer_config"])(model)
-        if not pass_obj.need_rewrite():
+        dim_generalizer = decorator_cls(self.config["dimension_generalizer_config"])
+        dim_gen_pass = dim_generalizer(model, dim_axes_pairs)
+        if not dim_gen_pass.need_rewrite():
             yield model_path
             return
         from dataclasses import asdict

         tensor_meta_attrs_list = [asdict(tensor_meta) for tensor_meta in tensor_metas]
-        graph_module = pass_obj.rewrite_with_tensor_meta_attrs_list(
-            tensor_meta_attrs_list
+        graph_module = dim_gen_pass.rewrite_with_tensor_meta_attrs_list(
+            tensor_meta_attrs_list=tensor_meta_attrs_list,
         )
         with tempfile.TemporaryDirectory() as tmp_dir:
             shutil.copytree(Path(model_path), Path(tmp_dir), dirs_exist_ok=True)
-            pass_obj.save_graph_module(graph_module, tmp_dir)
-            shutil.copy(Path(tmp_dir) / "model.py", Path("/tmp/a.py"))
+            dim_gen_pass.save_graph_module(graph_module, tmp_dir)
+            if self.config["last_model_log_file"] is not None:
+                log_file = Path(self.config["last_model_log_file"])
+                shutil.copy(Path(tmp_dir) / "model.py", log_file)
             yield tmp_dir
             # shutil.copytree(Path(tmp_dir), Path(model_path), dirs_exist_ok=True)

@@ -190,10 +217,40 @@ def make_dyn_dim_cstr_from_tensor_metas(tensor_metas: list[TensorMeta]):
     )


+class DynDimCstrFeasibilityPredicator:
+    def __init__(
+        self, is_dyn_dim_cstr_feasible: Callable[[DynamicDimConstraints], bool]
+    ):
+        self.is_dyn_dim_cstr_feasible = is_dyn_dim_cstr_feasible
+
+    def __call__(self, dyn_dim_cstr: DynamicDimConstraints) -> bool:
+        return self.is_dyn_dim_cstr_feasible(dyn_dim_cstr)
+
+
+class DynDimCstrFeasibilityContextManager:
+    def __init__(
+        self,
+        get_tmp_model_path_ctx_mgr,
+        get_predicator_is_dyn_dim_cstr_feasible,
+    ):
+        self.get_tmp_model_path_ctx_mgr = get_tmp_model_path_ctx_mgr
+        self.get_predicator_is_dyn_dim_cstr_feasible = (
+            get_predicator_is_dyn_dim_cstr_feasible
+        )
+
+    @contextmanager
+    def __call__(
+        self, dim_axes_pairs
+    ) -> AbstractContextManager[DynDimCstrFeasibilityPredicator]:
+        with self.get_tmp_model_path_ctx_mgr(dim_axes_pairs) as tmp_model_apth:
+            predicator = self.get_predicator_is_dyn_dim_cstr_feasible(tmp_model_apth)
+            yield DynDimCstrFeasibilityPredicator(predicator)
+
+
 def symbolize_data_input_dims(
     dyn_dim_cstr: DynamicDimConstraints,
     is_data_input: Callable[[str], bool],
-    is_dyn_dim_cstr_feasible: Callable[[DynamicDimConstraints], bool],
+    dyn_dim_cstr_feasibility_ctx_mgr: DynDimCstrFeasibilityContextManager,
 ) -> DynamicDimConstraints | None:
     """
     is_data_input: Callable[["input_var_name:str"], bool]
@@ -202,18 +259,21 @@ def symbolize_data_input_dims(
     Returns None if no symbolicable dim .
     """
     unqiue_dims = []
+    dim2axes = {}

     def dumpy_filter_fn(input_name, input_idx, axis, dim):
         if is_data_input(input_name):
             print("data_input", input_name, input_idx, axis, dim)
             if dim not in unqiue_dims:
                 unqiue_dims.append(dim)
-        # No symbolization because of returning True
+                dim2axes[dim] = []
+            dim2axes[dim].append(axis)
+        # No symbolization by returning False
         return False

     # Collect input dimensions into `unqiue_dims`
     assert dyn_dim_cstr.symbolize(dumpy_filter_fn) is None
-    for picked_dim in unqiue_dims:
+    for i, picked_dim in enumerate(unqiue_dims):
         cur_dyn_dim_cstr = copy.deepcopy(dyn_dim_cstr)

         def filter_fn(input_name, input_idx, axis, dim):
@@ -229,9 +289,15 @@ def filter_fn(input_name, input_idx, axis, dim):
         sym2example_value = {symbol: picked_dim + 1}
         if not cur_dyn_dim_cstr.check_delta_symbol2example_value(sym2example_value):
             continue
-        tmp_dyn_dim_cstr = copy.deepcopy(cur_dyn_dim_cstr)
-        tmp_dyn_dim_cstr.update_symbol2example_value(sym2example_value)
-        if not is_dyn_dim_cstr_feasible(tmp_dyn_dim_cstr):
-            continue
-        dyn_dim_cstr = cur_dyn_dim_cstr
+        dim_axes_pairs = tuple(
+            (dim, axes) for dim in unqiue_dims[: i + 1] for axes in [dim2axes[dim]]
+        )
+        with dyn_dim_cstr_feasibility_ctx_mgr(
+            dim_axes_pairs
+        ) as is_dyn_dim_cstr_feasible:
+            tmp_dyn_dim_cstr = copy.deepcopy(cur_dyn_dim_cstr)
+            tmp_dyn_dim_cstr.update_symbol2example_value(sym2example_value)
+            if not is_dyn_dim_cstr_feasible(tmp_dyn_dim_cstr):
+                continue
+            dyn_dim_cstr = cur_dyn_dim_cstr
     return dyn_dim_cstr
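
For orientation, here is a minimal sketch of how the new DynDimCstrFeasibilityContextManager is meant to be driven. The helpers suffixed _stub are hypothetical placeholders for the callbacks that __call__ builds from _try_dimension_generalization and _is_dyn_dim_cstr_feasible; only the context-manager plumbing itself comes from this commit.

# Minimal sketch, assuming graph_net is importable; the *_stub callbacks are
# placeholders for the real temporary-model rewriting and feasibility check.
from contextlib import contextmanager

from graph_net.constraint_util import DynDimCstrFeasibilityContextManager


@contextmanager
def get_tmp_model_path_ctx_mgr_stub(dim_axes_pairs):
    # Real version: rewrite the model for dim_axes_pairs into a temporary
    # directory and yield that directory's path.
    yield "/tmp/rewritten_model_stub"


def get_predicator_stub(tmp_model_path):
    # Real version: check that the candidate constraints still run against
    # the rewritten model found at tmp_model_path.
    return lambda dyn_dim_cstr: True


ctx_mgr = DynDimCstrFeasibilityContextManager(
    get_tmp_model_path_ctx_mgr=get_tmp_model_path_ctx_mgr_stub,
    get_predicator_is_dyn_dim_cstr_feasible=get_predicator_stub,
)

# symbolize_data_input_dims enters the context once per candidate dim and
# receives a DynDimCstrFeasibilityPredicator for testing each constraint set.
dim_axes_pairs = ((128, [0]),)  # illustrative value
with ctx_mgr(dim_axes_pairs) as is_dyn_dim_cstr_feasible:
    print(is_dyn_dim_cstr_feasible(None))  # True with the stub predicate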

graph_net/test/batch_init_input_tensor_constraints_test.sh

Lines changed: 2 additions & 1 deletion
@@ -18,7 +18,8 @@ config_json_str=$(cat <<EOF
     "model_runnable_predicator_filepath": "$GRAPH_NET_ROOT/torch/constraint_util.py",
     "model_runnable_predicator_class_name": "ModelRunnablePredicator",
     "dimension_generalizer_filepath": "$GRAPH_NET_ROOT/torch/static_to_dynamic.py",
-    "dimension_generalizer_class_name": "StaticToDynamic"
+    "dimension_generalizer_class_name": "StaticToDynamic",
+    "last_model_log_file": "/tmp/a.py"
   }
 }
 EOF

graph_net/test/shape_prop_batch_init_input_tensor_constraints_test.sh

Lines changed: 2 additions & 1 deletion
@@ -18,7 +18,8 @@ config_json_str=$(cat <<EOF
     "model_runnable_predicator_filepath": "$GRAPH_NET_ROOT/torch/constraint_util.py",
     "model_runnable_predicator_class_name": "ShapePropagatablePredicator",
     "dimension_generalizer_filepath": "$GRAPH_NET_ROOT/torch/static_to_dynamic.py",
-    "dimension_generalizer_class_name": "StaticToDynamic"
+    "dimension_generalizer_class_name": "StaticToDynamic",
+    "last_model_log_file": "/tmp/a.py"
   }
 }
 EOF
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+from graph_net.torch.dim_gen_passes.pass_base import DimensionGeneralizationPass
Lines changed: 122 additions & 0 deletions
@@ -0,0 +1,122 @@
+import torch
+import torch.fx as fx
+from graph_net.torch.dim_gen_passes import DimensionGeneralizationPass
+from torch.fx.passes.shape_prop import ShapeProp
+from graph_net.torch.utils import apply_templates
+from pathlib import Path
+import inspect
+from typing import Any
+from contextlib import contextmanager
+from torch.export import export
+from graph_net.torch.fx_graph_parse_util import parse_sole_graph_module
+
+
+class ConcretePass(DimensionGeneralizationPass):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def need_rewrite(self, traced_module: fx.GraphModule) -> bool:
+        if 0 not in self.axes:
+            return False
+        for node in traced_module.graph.nodes:
+            if node.op == "call_method" and node.target == "expand":
+                return True
+        return False
+
+    def rewrite(self, traced_module: fx.GraphModule) -> fx.GraphModule:
+        """
+        Fx Pass: Replaces hardcoded constants in 'expand' ops that match an input tensor dimension
+        with a dynamic 'size()' call. The primary goal is to dynamicize the batch size (axis 0).
+        """
+        # Create a new graph to hold the rewritten nodes
+        new_graph = fx.Graph()
+
+        # Create a map to link nodes from the old graph to nodes in the new graph
+        val_map = {}
+
+        for node in traced_module.graph.nodes:
+            if node.op == "call_method" and node.target == "expand":
+                # Get the input tensor node
+                input_tensor_node = node.args[0]
+                # Get the target shape arguments for expand (e.g., 1, 4, 6, 64)
+                expand_args = node.args[1:]
+
+                # --- Dependency on ShapeProp Results ---
+                # input_shape is the static shape (e.g., batch_size, C, H, W)
+                input_meta = input_tensor_node.meta.get("tensor_meta")
+                if input_meta is None:
+                    raise RuntimeError(
+                        f"Node {input_tensor_node.name} lacks tensor_meta. Did ShapeProp run?"
+                    )
+
+                input_shape = input_meta.shape
+
+                # Find the new list of expand arguments
+                new_expand_args = []
+
+                # Iterate over the target dimensions of expand (dim0, dim1, ...)
+                for i, target_dim in enumerate(expand_args):
+                    # 1. Handle dynamic dimensions (e.g., -1 or non-integer values)
+                    if not isinstance(target_dim, int) or target_dim < 1:
+                        new_expand_args.append(
+                            val_map[target_dim] if target_dim in val_map else target_dim
+                        )
+                        continue
+
+                    # 2. Handle hardcoded constants (e.g., 1, 6, 64)
+
+                    # --- Core Logic: Find the matching dynamic axis ---
+
+                    # Default: Keep the hardcoded constant if no matching dynamic axis is found
+                    best_match = target_dim
+                    matched_axis = -1
+
+                    axis_idx = i
+                    input_dim_size = input_shape[i]
+                    if target_dim == input_dim_size:
+                        if axis_idx == 0:
+                            matched_axis = axis_idx
+                        elif axis_idx > 0 and input_dim_size > 1:
+                            matched_axis = axis_idx
+                        else:
+                            # Do nothing.
+                            pass
+
+                    if matched_axis != -1:
+                        # Found a matching dynamic axis (matched_axis), replace it with a size() call
+
+                        # 1. Create a call to size(axis) in the new graph
+                        # NOTE: input_tensor_node must first be mapped to a new graph node via val_map
+                        new_input_node = val_map[input_tensor_node]
+
+                        # Use the size() method to retrieve the dynamic dimension
+                        size_node = new_graph.call_method(
+                            "size", args=(new_input_node, matched_axis)
+                        )
+
+                        best_match = size_node
+
+                    new_expand_args.append(best_match)
+
+                # --- Rebuild the expand node ---
+                # 1. Map the input tensor node to the new graph node
+                new_input_node = val_map[input_tensor_node]
+
+                # 2. Insert the new expand node into the new graph
+                # with new_graph.inserting_after(new_input_node):
+                new_node = new_graph.call_method(
+                    "expand", args=(new_input_node, *new_expand_args)
+                )
+
+                # 3. Map the old node to the new node
+                val_map[node] = new_node
+
+            else:
+                # Copy other nodes to the new graph
+                new_node = new_graph.node_copy(node, lambda x: val_map[x])
+                val_map[node] = new_node
+
+        # Replace the old graph with the new graph and return
+        traced_module.graph = new_graph
+        traced_module.recompile()
+        return traced_module
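
To make the effect of this pass concrete, here is a small self-contained sketch of the same expand-generalization idea written directly against torch.fx, independent of the pass_base plumbing (whose constructor signature is not part of this diff). TinyExpand and generalize_expand_dims are illustrative names, not part of the commit; the matching rule mirrors the one in ConcretePass.rewrite above.

import torch
import torch.fx as fx
from torch.fx.passes.shape_prop import ShapeProp


class TinyExpand(torch.nn.Module):
    def forward(self, x):
        # x has shape (4, 16) in the example run below; 4 and 16 are the
        # hardcoded constants we want to tie back to x's runtime shape.
        return x.unsqueeze(-1).expand(4, 16, 8)


def generalize_expand_dims(gm: fx.GraphModule) -> fx.GraphModule:
    """Replace expand() constants that equal the input's static size at the
    same axis with a size(axis) call."""
    graph = gm.graph
    for node in list(graph.nodes):
        if node.op != "call_method" or node.target != "expand":
            continue
        src = node.args[0]
        shape = src.meta["tensor_meta"].shape  # filled in by ShapeProp
        new_args = []
        for axis, dim in enumerate(node.args[1:]):
            matches = (
                isinstance(dim, int)
                and dim >= 1
                and dim == shape[axis]
                and (axis == 0 or dim > 1)
            )
            if matches:
                with graph.inserting_before(node):
                    dim = graph.call_method("size", args=(src, axis))
            new_args.append(dim)
        node.args = (src, *new_args)
    gm.recompile()
    return gm


gm = fx.symbolic_trace(TinyExpand())
ShapeProp(gm).propagate(torch.randn(4, 16))
generalize_expand_dims(gm)
print(gm.code)  # expand now reads its first two sizes via size(0) / size(1)

Unlike ConcretePass.rewrite, which rebuilds a fresh fx.Graph through a val_map, this sketch edits the traced graph in place; the matching rule (a constant equals the input's static size at the same axis, with axis 0 always eligible and other axes only when the size is greater than 1) is the same.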
