
Commit 2d35e3b

more dimension generalization passes for token dimension
1 parent 8173fea commit 2d35e3b

19 files changed: +614 additions, −73 deletions
Lines changed: 2 additions & 2 deletions

@@ -1,3 +1,3 @@
 # samples/timm/resnetaa50d.d_in12k
-samples/transformers-auto-model/opus-mt-en-gmw
-# samples/transformers-auto-model/Michielo_mt5-small_nl-en_translation
+# samples/transformers-auto-model/opus-mt-en-gmw
+samples/transformers-auto-model/Michielo_mt5-small_nl-en_translation

graph_net/constraint_util.py

Lines changed: 5 additions & 7 deletions

@@ -12,6 +12,7 @@
 import shutil
 from pathlib import Path
 import json
+from dataclasses import asdict
 
 
 class UpdateInputTensorConstraints:
@@ -75,7 +76,6 @@ def _make_config(
 
     def __call__(self, model_path):
         model_path = os.path.join(self.config["model_path_prefix"], model_path)
-        print(f"{model_path=}")
         cstr_path = os.path.join(model_path, "input_tensor_constraints.py")
         if (
             self.config["resume"]
@@ -143,15 +143,13 @@ def _try_dimension_generalization(self, dim_axes_pairs, model_path, tensor_metas
         )
         dim_generalizer = decorator_cls(self.config["dimension_generalizer_config"])
         dim_gen_pass = dim_generalizer(model, dim_axes_pairs)
-        if not dim_gen_pass.need_rewrite():
+        tensor_meta_attrs_list = [asdict(tensor_meta) for tensor_meta in tensor_metas]
+        inputs = dim_gen_pass.create_inputs_by_metas(tensor_meta_attrs_list)
+        if not dim_gen_pass.need_rewrite(inputs):
             yield model_path, ()
             return
-        from dataclasses import asdict
 
-        tensor_meta_attrs_list = [asdict(tensor_meta) for tensor_meta in tensor_metas]
-        graph_module = dim_gen_pass.rewrite_with_tensor_meta_attrs_list(
-            tensor_meta_attrs_list=tensor_meta_attrs_list,
-        )
+        graph_module = dim_gen_pass.rewrite(inputs)
         with tempfile.TemporaryDirectory() as tmp_dir:
             shutil.copytree(Path(model_path), Path(tmp_dir), dirs_exist_ok=True)
             dim_gen_pass.save_graph_module(graph_module, tmp_dir)
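The constraint_util.py change moves input construction out of the passes: asdict-serialized tensor metas are turned into concrete example inputs once, and the same inputs value drives both the need_rewrite check and the rewrite. A minimal sketch of the new calling convention (the driver function is hypothetical; create_inputs_by_metas, need_rewrite, and rewrite are the names from the diff):

import torch
from dataclasses import asdict

def run_dimension_generalization(dim_gen_pass, tensor_metas):
    # Hypothetical driver mirroring the control flow above.
    tensor_meta_attrs_list = [asdict(tm) for tm in tensor_metas]
    inputs = dim_gen_pass.create_inputs_by_metas(tensor_meta_attrs_list)
    if not dim_gen_pass.need_rewrite(inputs):
        return None  # nothing to generalize for this model
    return dim_gen_pass.rewrite(inputs)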

graph_net/imp_util.py

Lines changed: 1 addition & 0 deletions

@@ -4,6 +4,7 @@
 def load_module(path, name="unamed"):
     spec = imp.spec_from_file_location(name, path)
     module = imp.module_from_spec(spec)
+    module.__file__ = path
     spec.loader.exec_module(module)
     module.__graph_net_file_path__ = path
     return module
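Setting module.__file__ before exec_module means code inside a dynamically loaded pass file can rely on __file__ (as the get_pass_name methods below do). A usage sketch, with a hypothetical pass file path:

from graph_net.imp_util import load_module

# Hypothetical path; load_module executes the file and returns the module object.
mod = load_module("graph_net/torch/dim_gen_passes/naive_call_method_expand_pass.py")
print(mod.__file__)                 # set by the new line above, before execution
print(mod.__graph_net_file_path__)  # set after execution, as before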

graph_net/test/shape_prop_test.sh

Lines changed: 15 additions & 0 deletions

@@ -0,0 +1,15 @@
+#!/bin/bash
+
+GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(
+os.path.dirname(graph_net.__file__))")
+
+config_json_str=$(cat <<EOF
+{
+  "handler_path": "$GRAPH_NET_ROOT/torch/constraint_util.py",
+  "handler_class_name": "ShapePropagatablePredicator"
+}
+EOF
+)
+CONFIG=$(echo $config_json_str | base64 -w 0)
+
+python3 -m graph_net.model_path_handler --model-path $1 --handler-config=$CONFIG
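The handler config travels as a single base64-encoded argument. A Python sketch of the same round-trip (the decode half is an assumption about how graph_net.model_path_handler reads --handler-config):

import base64
import json

config = {
    "handler_path": "graph_net/torch/constraint_util.py",  # resolved path assumed
    "handler_class_name": "ShapePropagatablePredicator",
}
# Mirrors `base64 -w 0`: one-line base64 of the JSON text.
encoded = base64.b64encode(json.dumps(config).encode("utf-8")).decode("ascii")
# Presumed decode on the handler side.
assert json.loads(base64.b64decode(encoded)) == config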

graph_net/torch/dim_gen_passes/naive_call_method_expand_pass.py

Lines changed: 4 additions & 0 deletions

@@ -1,12 +1,16 @@
 import torch
 import torch.fx as fx
 from graph_net.torch.dim_gen_passes import DimensionGeneralizationPass
+import os
 
 
 class ConcretePass(DimensionGeneralizationPass):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
+    def get_pass_name(self) -> str:
+        return os.path.basename(__file__)[:-3]
+
     def need_rewrite(self, traced_module: fx.GraphModule) -> bool:
         if 0 not in self.axes:
             return False
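The same get_pass_name helper is added to the reshape and view passes below: it derives the pass name from the file name by stripping the trailing ".py" (the [:-3] slice removes exactly those three characters). A quick illustration of the idiom:

import os

# What get_pass_name computes for a hypothetical pass file:
path = "/repo/graph_net/torch/dim_gen_passes/naive_call_method_expand_pass.py"
print(os.path.basename(path)[:-3])  # -> "naive_call_method_expand_pass"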

graph_net/torch/dim_gen_passes/naive_call_method_reshape_pass.py

Lines changed: 4 additions & 0 deletions

@@ -1,12 +1,16 @@
 import torch
 import torch.fx as fx
 from graph_net.torch.dim_gen_passes import DimensionGeneralizationPass
+import os
 
 
 class ConcretePass(DimensionGeneralizationPass):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
+    def get_pass_name(self) -> str:
+        return os.path.basename(__file__)[:-3]
+
     def need_rewrite(self, traced_module: fx.GraphModule) -> bool:
         if 0 not in self.axes:
             return False

graph_net/torch/dim_gen_passes/naive_call_method_view_pass.py

Lines changed: 4 additions & 0 deletions

@@ -1,12 +1,16 @@
 import torch
 import torch.fx as fx
 from graph_net.torch.dim_gen_passes import DimensionGeneralizationPass
+import os
 
 
 class ConcretePass(DimensionGeneralizationPass):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
+    def get_pass_name(self) -> str:
+        return os.path.basename(__file__)[:-3]
+
     def need_rewrite(self, traced_module: fx.GraphModule) -> bool:
         if 0 not in self.axes:
             return False

graph_net/torch/dim_gen_passes/non_batch_call_function_arange_pass.py

Lines changed: 25 additions & 17 deletions

@@ -2,32 +2,38 @@
 import torch.fx as fx
 from graph_net.torch.dim_gen_passes import DimensionGeneralizationPass
 from collections import namedtuple
+import os
 
 
 class ConcretePass(DimensionGeneralizationPass):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
+    def get_pass_name(self) -> str:
+        return os.path.basename(__file__)[:-3]
+
     def need_rewrite(self, traced_module: fx.GraphModule) -> bool:
         # non batch
         if 0 in self.axes:
             return False
         return any(self._node_need_rewrite(node) for node in traced_module.graph.nodes)
 
+    def node_target(self):
+        return torch.arange
+
     def _node_need_rewrite(self, node) -> bool:
-        return (
-            node.op == "call_function"
-            and node.target == torch.arange
-            and len(node.args) >= 2
-            and node.args[0] == 0
-            and node.args[1] == self.dim
-        )
+        if not (node.op == "call_function"):
+            return False
+        if not (node.target == self.node_target()):
+            return False
+        if len(node.args) == 1:
+            return node.args[0] == self.dim
+        elif len(node.args) == 2:
+            return node.args[0] == 0 and node.args[1] == self.dim
+        else:
+            return False
 
     def rewrite(self, traced_module: fx.GraphModule) -> fx.GraphModule:
-        """
-        Fx Pass: Replaces hardcoded constants in 'torch.arange' ops that match an input tensor dimension
-        with a dynamic 'size()' call. The primary goal is to dynamicize the batch size.
-        """
         # Create a new graph to hold the rewritten nodes
         new_graph = fx.Graph()
 
@@ -50,8 +56,8 @@ def try_reset_last_node_axis(node, new_node):
                 last_node_axis = NodeAxis(node=new_node, shape_axis=axis)
                 return
 
-        def get_new_node_arg(i, arg):
-            if i != 1:
+        def get_new_node_arg(i, arg, len_args):
+            if not ((len_args == 2 and i == 1) or (len_args == 1 and i == 0)):
                 return val_map[arg] if arg in val_map else arg
             # i == 1
             assert arg == self.dim
@@ -70,12 +76,14 @@ def create_new_node(node):
                 try_reset_last_node_axis(node=node, new_node=new_node)
                 return new_node
 
-            new_arange_args = tuple(
-                get_new_node_arg(i, arg) for i, arg in enumerate(node.args)
+            new_node_args = tuple(
+                get_new_node_arg(i, arg, len(node.args))
+                for i, arg in enumerate(node.args)
             )
 
-            # --- Rebuild the torch.arange node ---
-            new_node = new_graph.call_function(torch.arange, args=new_arange_args)
+            new_node = new_graph.call_function(
+                self.node_target(), args=new_node_args, kwargs=node.kwargs
+            )
 
             return new_node
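To see what the generalized matcher and rewriter accomplish, here is a toy before/after (hand-written, not produced by the pass; self.dim is taken to be 128 on axis 1): a torch.arange whose hardcoded length equals the token dimension is rebuilt to read that dimension from a live tensor, covering both the one-arg arange(end) and two-arg arange(0, end) forms the matcher now accepts.

import torch

# Before (as traced with the token dim fixed at 128):
def before(x):                   # x: [batch, 128, hidden]
    return torch.arange(0, 128)  # length hardcoded to the token dim

# After the rewrite, conceptually: the constant becomes a size() read.
def after(x):
    return torch.arange(0, x.size(1))

x = torch.randn(2, 128, 16)
assert torch.equal(before(x), after(x))               # identical at the traced size
assert after(torch.randn(2, 256, 16)).numel() == 256  # generalizes to new lengths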

Lines changed: 118 additions & 0 deletions

@@ -0,0 +1,118 @@
+import torch
+import torch.fx as fx
+from graph_net.torch.dim_gen_passes import DimensionGeneralizationPass
+from collections import namedtuple
+import os
+import operator
+
+
+class ConcretePass(DimensionGeneralizationPass):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def get_pass_name(self) -> str:
+        return os.path.basename(__file__)[:-3]
+
+    def need_rewrite(self, traced_module: fx.GraphModule) -> bool:
+        # non batch
+        if 0 in self.axes:
+            return False
+        if not any(self._is_pivote_node(node) for node in traced_module.graph.nodes):
+            return False
+        return any(self._node_need_rewrite(node) for node in traced_module.graph.nodes)
+
+    def _node_target(self):
+        return torch.arange
+
+    def _is_pivote_node(self, node):
+        if not (node.op == "call_function"):
+            return False
+        if not (self._is_pivote_target(node.target)):
+            return False
+        if not (len(node.args) == 1):
+            return False
+        input_node = node.args[0]
+        input_node_meta = input_node.meta.get("tensor_meta")
+        if not (input_node_meta is not None):
+            return False
+        if not (hasattr(input_node_meta, "shape")):
+            return False
+        shape = input_node_meta.shape
+        if not any(dim == self.dim + self._dyn_dim_delta() for dim in shape):
+            return False
+        return True
+
+    def _is_pivote_target(self, target):
+        return target == torch.triu
+
+    def _node_need_rewrite(self, node) -> bool:
+        if not (node.op == "call_function"):
+            return False
+        if not (node.target == self._node_target()):
+            return False
+        if not (len(node.args) == 1):
+            return False
+        if not (node.args[0] == self.dim + self._dyn_dim_delta()):
+            return False
+        return True
+
+    def _dyn_dim_delta(self):
+        return 1
+
+    def rewrite(self, traced_module: fx.GraphModule) -> fx.GraphModule:
+        # Create a new graph to hold the rewritten nodes
+        new_graph = fx.Graph()
+
+        # Create a map to link nodes from the old graph to nodes in the new graph
+        val_map = {}
+
+        NodeAxis = namedtuple("NodeAxis", ["node", "shape_axis"])
+        last_node_axis = None
+
+        def try_reset_last_node_axis(node, new_node):
+            nonlocal last_node_axis
+            node_meta = node.meta.get("tensor_meta")
+            if node_meta is None:
+                return
+            if not hasattr(node_meta, "shape"):
+                return
+            for axis, dim in enumerate(node_meta.shape):
+                if not (dim == self.dim and axis > 0):
+                    continue
+                last_node_axis = NodeAxis(node=new_node, shape_axis=axis)
+                return
+
+        def get_new_node_dim(dim):
+            if not (dim == self.dim + self._dyn_dim_delta()):
+                return val_map[dim] if dim in val_map else dim
+            assert dim == self.dim + self._dyn_dim_delta()
+
+            # Use the size() method to retrieve the dynamic dimension
+            size_node = new_graph.call_method(
+                "size", args=(last_node_axis.node, last_node_axis.shape_axis)
+            )
+            plus_one_node = new_graph.call_function(operator.add, args=(size_node, 1))
+            return plus_one_node
+
+        def create_new_node(node):
+            if not (self._node_need_rewrite(node) and last_node_axis is not None):
+                # Copy other nodes to the new graph
+                new_node = new_graph.node_copy(node, lambda x: val_map[x])
+                try_reset_last_node_axis(node=node, new_node=new_node)
+                return new_node
+
+            new_node_dim = get_new_node_dim(dim=node.args[0])
+
+            new_node = new_graph.call_function(
+                self._node_target(), args=(new_node_dim,), kwargs=node.kwargs
+            )
+
+            return new_node
+
+        for node in traced_module.graph.nodes:
+            val_map[node] = create_new_node(node)
+
+        # Replace the old graph with the new graph and return
+        traced_module.graph = new_graph
+        traced_module.recompile()
+        return traced_module
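The new 118-line pass above handles a causal-mask idiom: a torch.triu pivot whose input shape carries dim + 1, plus a one-arg torch.arange(dim + 1); the rewrite rebuilds the constant as size() + 1 via operator.add. A hand-written before/after sketch under those assumptions (not produced by the pass itself):

import torch

DIM = 128  # token dimension at trace time (self.dim); _dyn_dim_delta() is +1

# Before: arange length hardcoded to DIM + 1, feeding a triu pivot.
def before(x):                       # x: [batch, DIM, hidden]
    idx = torch.arange(DIM + 1)
    return torch.triu(idx.expand(DIM + 1, DIM + 1))

# After, conceptually: the constant is recomputed as x.size(1) + 1.
def after(x):
    n = x.size(1) + 1
    idx = torch.arange(n)
    return torch.triu(idx.expand(n, n))

x = torch.randn(2, DIM, 16)
assert torch.equal(before(x), after(x))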
