Commit 8768f95

Generalize sequence dimension (#386)
* support checking model redundancy
* revert change of vision_model_test
* reformat python code
* reformat bert_model_test.py and utils.py
* minor fix
* fix failed check by comparing directories after os.path.realpath()
* fix bugs in check_validate.sh
* set dynamic=False in single_device_runner.py
* reset graph hash
* add robustness code for generating input tensor constraints
* introduce input_tensor_constraints.py using shape propagation logic
* support dimension generalization for torch.Tensor.view and torch.Tensor.reshape
* support dimension generalization for torch.Tensor.expand(); fix bugs in generalization for torch.Tensor.view and torch.Tensor.reshape
* add dimension_generalization_passes
* refactor DimensionGeneralizationPass.__init__ to accept a dim_axes_pairs argument, enabling targeted configuration for specific use cases
* save dimension generalization pass names into graph_net.json
* generalize sequence dimension
1 parent 508a99f commit 8768f95

9 files changed: +337 −40 lines

graph_net/constraint_util.py

Lines changed: 42 additions & 14 deletions

@@ -11,6 +11,7 @@
 import tempfile
 import shutil
 from pathlib import Path
+import json


 class UpdateInputTensorConstraints:
@@ -109,12 +110,13 @@ def is_dyn_dim_cstr_feasible(dyn_dim_cstr):
             get_tmp_model_path_ctx_mgr=get_tmp_model_path_ctx_mgr,
             get_predicator_is_dyn_dim_cstr_feasible=get_predicator_is_dyn_dim_cstr_feasible,
         )
-        dyn_dim_cstr = symbolize_data_input_dims(
+        dyn_dim_cstr, dim_gen_pass_names = symbolize_data_input_dims(
             dyn_dim_cstr,
             is_data_input=data_input_predicator,
             dyn_dim_cstr_feasibility_ctx_mgr=dyn_dim_cstr_feasibility_ctx_mgr,
         )
         self._save_dyn_dim_cstr(dyn_dim_cstr, model_path)
+        self._save_dim_gen_pass_names(dim_gen_pass_names, model_path)
         if len(dyn_dim_cstr.symbols) > 0:
             self.num_successful_handled_models += 1
             limits = self.config["limits_successfully_handled_models"]
@@ -129,7 +131,7 @@ def is_dyn_dim_cstr_feasible(dyn_dim_cstr):
     @contextmanager
     def _try_dimension_generalization(self, dim_axes_pairs, model_path, tensor_metas):
         if self.config["dimension_generalizer_filepath"] is None:
-            yield model_path
+            yield model_path, ()
             return
         py_module = load_module(os.path.join(model_path, "model.py"))
         GraphModule = getattr(py_module, "GraphModule")
@@ -142,7 +144,7 @@ def _try_dimension_generalization(self, dim_axes_pairs, model_path, tensor_metas
         dim_generalizer = decorator_cls(self.config["dimension_generalizer_config"])
         dim_gen_pass = dim_generalizer(model, dim_axes_pairs)
         if not dim_gen_pass.need_rewrite():
-            yield model_path
+            yield model_path, ()
             return
         from dataclasses import asdict

@@ -156,8 +158,15 @@ def _try_dimension_generalization(self, dim_axes_pairs, model_path, tensor_metas
         if self.config["last_model_log_file"] is not None:
             log_file = Path(self.config["last_model_log_file"])
             shutil.copy(Path(tmp_dir) / "model.py", log_file)
-        yield tmp_dir
-        # shutil.copytree(Path(tmp_dir), Path(model_path), dirs_exist_ok=True)
+        yield tmp_dir, dim_gen_pass.get_pass_names()
+
+    def _save_dim_gen_pass_names(self, dim_gen_pass_names, model_path):
+        from graph_net.graph_net_json_file_util import kDimensionGeneralizationPasses
+
+        graph_net_json_file_path = Path(f"{model_path}/graph_net.json")
+        graph_net_json = json.loads(graph_net_json_file_path.read_text())
+        graph_net_json[kDimensionGeneralizationPasses] = list(dim_gen_pass_names)
+        graph_net_json_file_path.write_text(json.dumps(graph_net_json))

     def _save_dyn_dim_cstr(self, dyn_dim_cstr, model_path):
         cstr_code = dyn_dim_cstr.serialize_to_py_str()
@@ -219,9 +228,12 @@ def make_dyn_dim_cstr_from_tensor_metas(tensor_metas: list[TensorMeta]):

 class DynDimCstrFeasibilityPredicator:
     def __init__(
-        self, is_dyn_dim_cstr_feasible: Callable[[DynamicDimConstraints], bool]
+        self,
+        is_dyn_dim_cstr_feasible: Callable[[DynamicDimConstraints], bool],
+        dim_gen_pass_names: tuple[str],
     ):
         self.is_dyn_dim_cstr_feasible = is_dyn_dim_cstr_feasible
+        self.dim_gen_pass_names = dim_gen_pass_names

     def __call__(self, dyn_dim_cstr: DynamicDimConstraints) -> bool:
         return self.is_dyn_dim_cstr_feasible(dyn_dim_cstr)
@@ -242,16 +254,17 @@ def __init__(
     def __call__(
         self, dim_axes_pairs
     ) -> AbstractContextManager[DynDimCstrFeasibilityPredicator]:
-        with self.get_tmp_model_path_ctx_mgr(dim_axes_pairs) as tmp_model_apth:
+        ctx_mgr = self.get_tmp_model_path_ctx_mgr
+        with ctx_mgr(dim_axes_pairs) as (tmp_model_apth, dg_pass_names):
             predicator = self.get_predicator_is_dyn_dim_cstr_feasible(tmp_model_apth)
-            yield DynDimCstrFeasibilityPredicator(predicator)
+            yield DynDimCstrFeasibilityPredicator(predicator, dg_pass_names)


 def symbolize_data_input_dims(
     dyn_dim_cstr: DynamicDimConstraints,
     is_data_input: Callable[[str], bool],
     dyn_dim_cstr_feasibility_ctx_mgr: DynDimCstrFeasibilityContextManager,
-) -> DynamicDimConstraints | None:
+) -> (DynamicDimConstraints | None, tuple[str]):
     """
     is_data_input: Callable[["input_var_name:str"], bool]
     Symbolizes data input dimensions as much as possible.
@@ -273,6 +286,21 @@ def dumpy_filter_fn(input_name, input_idx, axis, dim):

     # Collect input dimensions into `unqiue_dims`
     assert dyn_dim_cstr.symbolize(dumpy_filter_fn) is None
+    total_dim_gen_pass_names = ()
+
+    def append_dim_gen_pass_names(dim_gen_pass_names):
+        nonlocal total_dim_gen_pass_names
+        total_dim_gen_pass_names = tuple(
+            [
+                *total_dim_gen_pass_names,
+                *(
+                    pass_name
+                    for pass_name in dim_gen_pass_names
+                    if pass_name not in total_dim_gen_pass_names
+                ),
+            ]
+        )
+
     for i, picked_dim in enumerate(unqiue_dims):
         cur_dyn_dim_cstr = copy.deepcopy(dyn_dim_cstr)

@@ -292,12 +320,12 @@ def filter_fn(input_name, input_idx, axis, dim):
         dim_axes_pairs = tuple(
             (dim, axes) for dim in unqiue_dims[: i + 1] for axes in [dim2axes[dim]]
         )
-        with dyn_dim_cstr_feasibility_ctx_mgr(
-            dim_axes_pairs
-        ) as is_dyn_dim_cstr_feasible:
+        ctx_mgr = dyn_dim_cstr_feasibility_ctx_mgr
+        with ctx_mgr(dim_axes_pairs) as dyn_dim_cstr_feasibility:
             tmp_dyn_dim_cstr = copy.deepcopy(cur_dyn_dim_cstr)
             tmp_dyn_dim_cstr.update_symbol2example_value(sym2example_value)
-            if not is_dyn_dim_cstr_feasible(tmp_dyn_dim_cstr):
+            if not dyn_dim_cstr_feasibility(tmp_dyn_dim_cstr):
                 continue
             dyn_dim_cstr = cur_dyn_dim_cstr
-    return dyn_dim_cstr
+            append_dim_gen_pass_names(dyn_dim_cstr_feasibility.dim_gen_pass_names)
+    return dyn_dim_cstr, total_dim_gen_pass_names
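
With this change, `symbolize_data_input_dims` returns the constraints together with the names of the dimension-generalization passes that made them feasible, and `_save_dim_gen_pass_names` records those names in the model's `graph_net.json`. A minimal, self-contained sketch of that bookkeeping step; the model directory, the pre-existing metadata, and the pass names below are hypothetical, only the `dimension_generalization_passes` key comes from this commit:

```python
import json
from pathlib import Path

# Hypothetical model directory with a pre-existing graph_net.json (contents assumed).
model_path = Path("/tmp/example_model")
model_path.mkdir(parents=True, exist_ok=True)
graph_net_json_path = model_path / "graph_net.json"
graph_net_json_path.write_text(json.dumps({"framework": "torch"}))

# Equivalent of UpdateInputTensorConstraints._save_dim_gen_pass_names(...):
# merge the pass names into the existing metadata under the new key.
dim_gen_pass_names = ("naive_call_method_view_pass", "naive_call_method_reshape_pass")  # illustrative
data = json.loads(graph_net_json_path.read_text())
data["dimension_generalization_passes"] = list(dim_gen_pass_names)
graph_net_json_path.write_text(json.dumps(data))

print(json.loads(graph_net_json_path.read_text()))
# {'framework': 'torch', 'dimension_generalization_passes': ['naive_call_method_view_pass', 'naive_call_method_reshape_pass']}
```
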
graph_net/graph_net_json_file_util.py

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+kDimensionGeneralizationPasses = "dimension_generalization_passes"

graph_net/torch/dim_gen_passes/naive_call_method_expand_pass.py

Lines changed: 0 additions & 8 deletions

@@ -1,14 +1,6 @@
 import torch
 import torch.fx as fx
 from graph_net.torch.dim_gen_passes import DimensionGeneralizationPass
-from torch.fx.passes.shape_prop import ShapeProp
-from graph_net.torch.utils import apply_templates
-from pathlib import Path
-import inspect
-from typing import Any
-from contextlib import contextmanager
-from torch.export import export
-from graph_net.torch.fx_graph_parse_util import parse_sole_graph_module


 class ConcretePass(DimensionGeneralizationPass):

graph_net/torch/dim_gen_passes/naive_call_method_reshape_pass.py

Lines changed: 0 additions & 8 deletions

@@ -1,14 +1,6 @@
 import torch
 import torch.fx as fx
 from graph_net.torch.dim_gen_passes import DimensionGeneralizationPass
-from torch.fx.passes.shape_prop import ShapeProp
-from graph_net.torch.utils import apply_templates
-from pathlib import Path
-import inspect
-from typing import Any
-from contextlib import contextmanager
-from torch.export import export
-from graph_net.torch.fx_graph_parse_util import parse_sole_graph_module


 class ConcretePass(DimensionGeneralizationPass):

graph_net/torch/dim_gen_passes/naive_call_method_view_pass.py

Lines changed: 0 additions & 8 deletions

@@ -1,14 +1,6 @@
 import torch
 import torch.fx as fx
 from graph_net.torch.dim_gen_passes import DimensionGeneralizationPass
-from torch.fx.passes.shape_prop import ShapeProp
-from graph_net.torch.utils import apply_templates
-from pathlib import Path
-import inspect
-from typing import Any
-from contextlib import contextmanager
-from torch.export import export
-from graph_net.torch.fx_graph_parse_util import parse_sole_graph_module


 class ConcretePass(DimensionGeneralizationPass):
Lines changed: 88 additions & 0 deletions

@@ -0,0 +1,88 @@
+import torch
+import torch.fx as fx
+from graph_net.torch.dim_gen_passes import DimensionGeneralizationPass
+from collections import namedtuple
+
+
+class ConcretePass(DimensionGeneralizationPass):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def need_rewrite(self, traced_module: fx.GraphModule) -> bool:
+        # non batch
+        if 0 in self.axes:
+            return False
+        return any(self._node_need_rewrite(node) for node in traced_module.graph.nodes)
+
+    def _node_need_rewrite(self, node) -> bool:
+        return (
+            node.op == "call_function"
+            and node.target == torch.arange
+            and len(node.args) >= 2
+            and node.args[0] == 0
+            and node.args[1] == self.dim
+        )
+
+    def rewrite(self, traced_module: fx.GraphModule) -> fx.GraphModule:
+        """
+        Fx Pass: Replaces hardcoded constants in 'torch.arange' ops that match an input tensor dimension
+        with a dynamic 'size()' call. The primary goal is to dynamicize the batch size.
+        """
+        # Create a new graph to hold the rewritten nodes
+        new_graph = fx.Graph()
+
+        # Create a map to link nodes from the old graph to nodes in the new graph
+        val_map = {}
+
+        NodeAxis = namedtuple("NodeAxis", ["node", "shape_axis"])
+        last_node_axis = None
+
+        def try_reset_last_node_axis(node, new_node):
+            nonlocal last_node_axis
+            node_meta = node.meta.get("tensor_meta")
+            if node_meta is None:
+                return
+            if not hasattr(node_meta, "shape"):
+                return
+            for axis, dim in enumerate(node_meta.shape):
+                if not (dim == self.dim and axis > 0):
+                    continue
+                last_node_axis = NodeAxis(node=new_node, shape_axis=axis)
+                return
+
+        def get_new_node_arg(i, arg):
+            if i != 1:
+                return val_map[arg] if arg in val_map else arg
+            # i == 1
+            assert arg == self.dim
+
+            # Use the size() method to retrieve the dynamic dimension
+            size_node = new_graph.call_method(
+                "size", args=(last_node_axis.node, last_node_axis.shape_axis)
+            )
+
+            return size_node
+
+        def create_new_node(node):
+            if not (self._node_need_rewrite(node) and last_node_axis is not None):
+                # Copy other nodes to the new graph
+                new_node = new_graph.node_copy(node, lambda x: val_map[x])
+                try_reset_last_node_axis(node=node, new_node=new_node)
+                return new_node
+
+            new_arange_args = tuple(
+                get_new_node_arg(i, arg) for i, arg in enumerate(node.args)
+            )
+
+            # --- Rebuild the torch.arange node ---
+            new_node = new_graph.call_function(torch.arange, args=new_arange_args)
+
+            return new_node
+
+        for node in traced_module.graph.nodes:
+            val_map[node] = create_new_node(node)
+
+        # Replace the old graph with the new graph and return
+        traced_module.graph = new_graph
+        traced_module.recompile()
+        return traced_module
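
The new `torch.arange` pass above only fires for non-batch axes (`need_rewrite` bails out when axis 0 is targeted) and replaces an `arange(0, <constant>)` whose stop equals the tracked dimension with a `size()` call on the most recent upstream node carrying that dimension. A rough before/after sketch using plain `torch.fx`; the module names and the concrete dim/axes values (sequence length 128 at axis 1) are illustrative, not taken from the commit:

```python
import torch
import torch.fx as fx
from torch.fx.passes.shape_prop import ShapeProp


class Before(torch.nn.Module):
    def forward(self, x):                      # x: [batch, 128, hidden]
        pos = torch.arange(0, 128)             # sequence length hard-coded
        return x + pos.reshape(1, -1, 1)


gm = fx.symbolic_trace(Before())
ShapeProp(gm).propagate(torch.randn(2, 128, 64))   # fills node.meta["tensor_meta"]

# After the pass rewrites the graph (configured, presumably, with dim=128 and
# axes=(1,) via dim_axes_pairs), the arange node reads the length from the
# input instead, roughly equivalent to:
class AfterSketch(torch.nn.Module):
    def forward(self, x):
        pos = torch.arange(0, x.size(1))       # dynamic sequence dimension
        return x + pos.reshape(1, -1, 1)
```
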
Lines changed: 113 additions & 0 deletions

@@ -0,0 +1,113 @@
+import torch
+import torch.fx as fx
+from graph_net.torch.dim_gen_passes import DimensionGeneralizationPass
+from collections import namedtuple
+import operator
+import copy
+
+
+class ConcretePass(DimensionGeneralizationPass):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def need_rewrite(self, traced_module: fx.GraphModule) -> bool:
+        # non batch
+        if 0 in self.axes:
+            return False
+        return any(self._node_need_rewrite(node) for node in traced_module.graph.nodes)
+
+    def _node_need_rewrite(self, node) -> bool:
+        if not (node.op == "call_function"):
+            return False
+        if not (node.target == operator.getitem):
+            return False
+        if not isinstance(node.args[1], tuple):
+            return False
+        if not any(self._slice_need_rewrite(slice_obj) for slice_obj in node.args[1]):
+            return False
+        return True
+
+    def _slice_need_rewrite(self, slice_obj) -> bool:
+        if not isinstance(slice_obj, slice):
+            return False
+        return (
+            slice_obj.stop == self.dim
+            and (slice_obj.start is None or slice_obj.start == 0)
+            and (slice_obj.step is None or slice_obj.step == 1)
+        )
+
+    def rewrite(self, traced_module: fx.GraphModule) -> fx.GraphModule:
+        """
+        Fx Pass: Replaces hardcoded constants in 'operator.getitem' ops that match an input tensor dimension
+        with a dynamic 'size()' call. The primary goal is to dynamicize the batch size.
+        """
+        # Create a new graph to hold the rewritten nodes
+        new_graph = fx.Graph()
+
+        # Create a map to link nodes from the old graph to nodes in the new graph
+        val_map = {}
+
+        NodeAxis = namedtuple("NodeAxis", ["node", "shape_axis"])
+        last_node_axis = None
+
+        def try_reset_last_node_axis(node, new_node):
+            nonlocal last_node_axis
+            node_meta = node.meta.get("tensor_meta")
+            if node_meta is None:
+                return
+            if not hasattr(node_meta, "shape"):
+                return
+            for axis, dim in enumerate(node_meta.shape):
+                if not (dim == self.dim and axis > 0):
+                    continue
+                last_node_axis = NodeAxis(node=new_node, shape_axis=axis)
+                return
+
+        def val_map_contains(key):
+            if isinstance(key, slice):
+                return False
+            return key in val_map
+
+        def get_new_getitem_tuple_elem(elem):
+            if not (
+                isinstance(elem, slice)
+                and self._slice_need_rewrite(elem)
+                and last_node_axis is not None
+            ):
+                return val_map[elem] if val_map_contains(elem) else elem
+
+            slice_obj = copy.deepcopy(elem)
+            assert slice_obj.stop == self.dim
+
+            # Use the size() method to retrieve the dynamic dimension
+            size_node = new_graph.call_method(
+                "size", args=(last_node_axis.node, last_node_axis.shape_axis)
+            )
+            return slice(slice_obj.start, size_node, slice_obj.step)
+
+        def get_new_getitem_arg(arg):
+            if not isinstance(arg, tuple):
+                return val_map[arg] if arg in val_map else arg
+            return tuple(get_new_getitem_tuple_elem(elem) for elem in arg)
+
+        def create_new_node(node):
+            if not self._node_need_rewrite(node):
+                # Copy other nodes to the new graph
+                new_node = new_graph.node_copy(node, lambda x: val_map[x])
+                try_reset_last_node_axis(node=node, new_node=new_node)
+                return new_node
+
+            new_gettiem_args = tuple(get_new_getitem_arg(arg) for arg in node.args)
+
+            # --- Rebuild the operator.getitem node ---
+            new_node = new_graph.call_function(operator.getitem, args=new_gettiem_args)
+
+            return new_node
+
+        for node in traced_module.graph.nodes:
+            val_map[node] = create_new_node(node)
+
+        # Replace the old graph with the new graph and return
+        traced_module.graph = new_graph
+        traced_module.recompile()
+        return traced_module
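
Similarly, the `operator.getitem` pass above rewrites slices whose `stop` equals the tracked dimension (with start 0/None and step 1/None) so the stop is read dynamically via `size()` from the last upstream node carrying that dimension. A rough before/after sketch; module names, buffer layout, and shapes are illustrative only:

```python
import torch
import torch.fx as fx
from torch.fx.passes.shape_prop import ShapeProp


class Before(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("pos_ids", torch.arange(512).unsqueeze(0))

    def forward(self, x):                       # x: [batch, 128, hidden]
        ids = self.pos_ids[:, 0:128]            # getitem with a hard-coded stop
        return x.sum(-1) + ids


gm = fx.symbolic_trace(Before())
ShapeProp(gm).propagate(torch.randn(2, 128, 64))

# After the pass, the slice stop is taken from the input, roughly equivalent to:
class AfterSketch(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("pos_ids", torch.arange(512).unsqueeze(0))

    def forward(self, x):
        ids = self.pos_ids[:, 0 : x.size(1)]    # dynamic sequence dimension
        return x.sum(-1) + ids
```
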
