Skip to content

Commit ebe9968

Browse files
committed
generate initial input_tensor_constraints.py
1 parent cba4e86 commit ebe9968

File tree

7 files changed

+807
-58
lines changed

7 files changed

+807
-58
lines changed

graph_net/constraint_util.py

Lines changed: 141 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,174 @@
11
from graph_net.dynamic_dim_constraints import DynamicDimConstraints
2+
from graph_net.imp_util import load_module
3+
from graph_net.tensor_meta import TensorMeta
24
from typing import Callable
35
import copy
6+
import sys
7+
import os
8+
9+
10+
class UpdateInputTensorConstraints:
11+
def __init__(self, config=None):
12+
if config is None:
13+
config = {}
14+
self.config = self._make_config(**config)
15+
self.data_input_predicator = self._make_data_input_predicator(self.config)
16+
self.model_runnable_predicator = self._make_model_runnable_predicator(
17+
self.config
18+
)
19+
20+
def _make_data_input_predicator(self, config):
21+
module = load_module(config["data_input_predicator_filepath"])
22+
cls = getattr(module, config["data_input_predicator_class_name"])
23+
return cls(config["data_input_predicator_config"])
24+
25+
def _make_model_runnable_predicator(self, config):
26+
module = load_module(config["model_runnable_predicator_filepath"])
27+
cls = getattr(module, config["model_runnable_predicator_class_name"])
28+
return cls(config["model_runnable_predicator_config"])
29+
30+
def _make_config(
31+
self,
32+
data_input_predicator_filepath,
33+
model_runnable_predicator_filepath,
34+
data_input_predicator_class_name="DataInputPredicator",
35+
data_input_predicator_config=None,
36+
model_runnable_predicator_class_name="ModelRunner",
37+
model_runnable_predicator_config=None,
38+
):
39+
if data_input_predicator_config is None:
40+
data_input_predicator_config = {}
41+
if model_runnable_predicator_config is None:
42+
model_runnable_predicator_config = {}
43+
return {
44+
"data_input_predicator_filepath": data_input_predicator_filepath,
45+
"data_input_predicator_class_name": data_input_predicator_class_name,
46+
"data_input_predicator_config": data_input_predicator_config,
47+
"model_runnable_predicator_filepath": model_runnable_predicator_filepath,
48+
"model_runnable_predicator_class_name": model_runnable_predicator_class_name,
49+
"model_runnable_predicator_config": model_runnable_predicator_config,
50+
}
51+
52+
def __call__(self, model_path):
53+
tensor_metas = self._get_tensor_metas(model_path)
54+
dyn_dim_cstr = make_dyn_dim_cstr_from_tensor_metas(tensor_metas)
55+
56+
def data_input_predicator(input_var_name):
57+
return self.data_input_predicator(model_path, input_var_name)
58+
59+
def is_dyn_dim_cstr_feasible(dyn_dim_cstr):
60+
return self._is_dyn_dim_cstr_feasible(
61+
model_path, tensor_metas, dyn_dim_cstr
62+
)
63+
64+
dyn_dim_cstr = symbolize_data_input_dims(
65+
dyn_dim_cstr,
66+
is_data_input=data_input_predicator,
67+
is_dyn_dim_cstr_feasible=is_dyn_dim_cstr_feasible,
68+
)
69+
self._save_dyn_dim_cstr(dyn_dim_cstr, model_path)
70+
71+
def _save_dyn_dim_cstr(self, dyn_dim_cstr, model_path):
72+
cstr_code = dyn_dim_cstr.serialize_to_py_str()
73+
with open(os.path.join(model_path, "input_tensor_constraints.py"), "w") as fp:
74+
fp.write(cstr_code)
75+
76+
def _get_tensor_metas(self, model_path):
77+
make = TensorMeta.unserialize_from_py_file
78+
return [
79+
*make(os.path.join(model_path, "input_meta.py")),
80+
*make(os.path.join(model_path, "weight_meta.py")),
81+
]
82+
83+
def _is_dyn_dim_cstr_feasible(
84+
self, model_path, tensor_metas, dyn_dim_cstr: DynamicDimConstraints
85+
):
86+
tensor_metas = copy.deepcopy(tensor_metas)
87+
update_tensor_metas_by_dyn_dim_cstr(tensor_metas, dyn_dim_cstr)
88+
weight_meta_code = "\n".join(
89+
tensor_meta.serialize_to_py_str() for tensor_meta in tensor_metas
90+
)
91+
import tempfile
92+
93+
with tempfile.TemporaryDirectory() as tmpdir:
94+
for filename in ["graph_net.json", "model.py"]:
95+
with open(os.path.join(tmpdir, filename), "w") as f:
96+
f.write(open(os.path.join(model_path, filename)).read())
97+
with open(os.path.join(tmpdir, "input_meta.py"), "w") as f:
98+
f.write("")
99+
with open(os.path.join(tmpdir, "weight_meta.py"), "w") as f:
100+
f.write(weight_meta_code)
101+
return self.model_runnable_predicator(tmpdir)
102+
103+
104+
def update_tensor_metas_by_dyn_dim_cstr(
105+
tensor_metas: list[TensorMeta], dyn_dim_cstr: DynamicDimConstraints
106+
):
107+
input_shapes = dyn_dim_cstr.get_reified_input_shapes()
108+
input_max_values = dyn_dim_cstr.get_reified_input_max_values()
109+
assert len(tensor_metas) == len(input_shapes)
110+
assert len(tensor_metas) == len(input_max_values)
111+
for i, tensor_meta in enumerate(tensor_metas):
112+
tensor_meta.shape = input_shapes[i]
113+
tensor_meta.max_val = input_max_values[i]
114+
115+
116+
def make_dyn_dim_cstr_from_tensor_metas(tensor_metas: list[TensorMeta]):
117+
named_shapes = [
118+
(shape, name)
119+
for tensor_meta in tensor_metas
120+
for name in [tensor_meta.name]
121+
for shape in [tensor_meta.shape]
122+
]
123+
named_max_values = [
124+
(max_val, name)
125+
for tensor_meta in tensor_metas
126+
for name in [tensor_meta.name]
127+
for max_val in [tensor_meta.max_val]
128+
]
129+
return DynamicDimConstraints.make_by_named_inputs(
130+
named_shapes=named_shapes,
131+
named_max_values=named_max_values,
132+
)
4133

5134

6135
def symbolize_data_input_dims(
7136
dyn_dim_cstr: DynamicDimConstraints,
8-
is_data_input: Callable[["input_var_name:str"], bool],
9-
is_input_shape_valid: Callable[[DynamicDimConstraints], bool],
137+
is_data_input: Callable[[str], bool],
138+
is_dyn_dim_cstr_feasible: Callable[[DynamicDimConstraints], bool],
10139
) -> DynamicDimConstraints | None:
11140
"""
141+
is_data_input: Callable[["input_var_name:str"], bool]
12142
Symbolizes data input dimensions as much as possible.
13143
Returns new DynamicDimConstraints if success.
14144
Returns None if no symbolicable dim .
15145
"""
16146
unqiue_dims = set()
17147

18148
def dumpy_filter_fn(input_name, input_idx, axis, dim):
19-
unqiue_dims.add(dim)
149+
if is_data_input(input_name):
150+
print("data_input", input_name, input_idx, axis, dim)
151+
unqiue_dims.add(dim)
20152
# No symbolization because of returning True
21153
return False
22154

23155
# Collect input dimensions into `unqiue_dims`
24156
assert dyn_dim_cstr.symbolize(dumpy_filter_fn) is None
25157
for picked_dim in unqiue_dims:
26-
tmp_dyn_dim_cstr = copy.deepcopy(dyn_dim_cstr)
158+
cur_dyn_dim_cstr = copy.deepcopy(dyn_dim_cstr)
27159

28160
def filter_fn(input_name, input_idx, axis, dim):
29161
return is_data_input(input_name) and dim == picked_dim
30162

31-
symbol = tmp_dyn_dim_cstr.symbolize(filter_fn)
163+
symbol = cur_dyn_dim_cstr.symbolize(filter_fn)
32164
if symbol is None:
33165
continue
34166
sym2example_value = {symbol: picked_dim + 1}
35-
if not tmp_dyn_dim_cstr.check_delta_symbol2example_value(sym2example_value):
167+
if not cur_dyn_dim_cstr.check_delta_symbol2example_value(sym2example_value):
36168
continue
169+
tmp_dyn_dim_cstr = copy.deepcopy(cur_dyn_dim_cstr)
37170
tmp_dyn_dim_cstr.update_symbol2example_value(sym2example_value)
38-
if not is_input_shape_valid(tmp_dyn_dim_cstr):
171+
if not is_dyn_dim_cstr_feasible(tmp_dyn_dim_cstr):
39172
continue
40-
dyn_dim_cstr = tmp_dyn_dim_cstr
173+
dyn_dim_cstr = cur_dyn_dim_cstr
41174
return dyn_dim_cstr

graph_net/dynamic_dim_constraints.py

Lines changed: 80 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
@dataclass
1111
class DynamicDimConstraints:
12-
kSymbolVarNamePrefix = "s"
12+
kSymbolVarNamePrefix = "S"
1313

1414
symbols: list[sympy.Symbol]
1515
kSymbols = "dynamic_dim_constraint_symbols"
@@ -22,27 +22,38 @@ class DynamicDimConstraints:
2222
kRelations = "dynamic_dim_constraint_relations"
2323

2424
# len(input_shapes) equals number of Model.forward arguments
25-
input_shapes: list[("var-name", tuple[sympy.Expr | int])]
25+
input_shapes: list[(tuple[sympy.Expr | int], "var-name")]
2626
kInputShapes = "dynamic_dim_constraint_input_shapes"
2727

2828
# len(input_shapes) equals number of Model.forward arguments
29-
input_max_values: list[("var-name", tuple[sympy.Expr | int | None])]
29+
input_max_values: list[(tuple[sympy.Expr | int | None], "var-name")]
3030
kInputMaxValues = "dynamic_dim_constraint_input_max_values"
3131

32+
@classmethod
33+
def make_by_named_inputs(cls, named_shapes, named_max_values):
34+
return cls(
35+
symbols=[],
36+
symbol2example_value={},
37+
relations=[],
38+
input_shapes=named_shapes,
39+
input_max_values=named_max_values,
40+
)
41+
3242
def symbolize(
3343
self,
34-
filter_fn=Callable[
35-
["input_name:str", "input_idx:int", "axis:int", "dim:int"], bool
36-
],
44+
filter_fn: Callable[[str, int, int, int], bool],
3745
) -> sympy.Symbol | None:
3846
"""
47+
filter_fn: Callable[
48+
["input_name:str", "input_idx:int", "axis:int", "dim:int"], bool
49+
]
3950
Returns created symbol.
4051
"""
4152
InputDim = namedtuple("InputDim", ["input_idx", "axis", "dim"])
4253
input_dims = [
4354
InputDim(input_idx, axis, dim)
4455
for input_idx, namedshape in enumerate(self.input_shapes)
45-
for input_name, shape in [namedshape]
56+
for shape, input_name in [namedshape]
4657
for axis, dim in enumerate(shape)
4758
if isinstance(dim, int)
4859
if filter_fn(input_name, input_idx, axis, dim)
@@ -54,41 +65,34 @@ def symbolize(
5465
dim = list(unique_dims)[0]
5566
new_sym = self._new_symbol(example_value=dim)
5667
for input_dix, axis, _ in input_dims:
57-
self.input_shapes[input_dix][1][axis] = new_sym
68+
self.input_shapes[input_dix][0][axis] = new_sym
5869
return new_sym
5970

60-
def _new_symbol(self, example_value):
61-
max_existed_seq_no = max(
62-
[
63-
-1,
64-
*(
65-
seq_no
66-
for symbol in self.symbols
67-
for seq_no in [int(symbol.name[1:])]
68-
),
69-
]
71+
def update_symbol2example_value(self, symbol2example_value: dict):
72+
self.symbol2example_value = self._merge_symbol2example_value(
73+
symbol2example_value
7074
)
71-
seq_no = max_existed_seq_no + 1
72-
symbol = sympy.Symbol(f"{self.kSymbolVarNamePrefix}{seq_no}")
73-
self.symbol2example_value[symbol] = example_value
74-
self.symbols.append(symbol)
75-
return symbol
76-
77-
def update_symbol2example_value(self, sym2example_value: dict):
78-
self.symbol2example_value = self.merge_symbol2example_value(sym2example_value)
7975
return self
8076

81-
def merge_symbol2example_value(self, sym2example_value: dict):
82-
return {
83-
k: v
84-
for k, v in [*self.symbol2example_value.items(), *sym2example_value.items()]
85-
}
77+
def get_reified_input_shapes(self):
78+
return [
79+
[self._try_reify(dim) for dim in shape] for shape, name in self.input_shapes
80+
]
81+
82+
def get_reified_input_max_values(self):
83+
return [self._try_reify(max_value) for max_value, name in self.input_max_values]
8684

87-
def check_delta_symbol2example_value(self, sym2example_value: dict):
88-
if len(sym2example_value) == 0:
85+
def _try_reify(self, dim):
86+
if isinstance(dim, sympy.Expr):
87+
dim = int(dim.subs(self.symbol2example_value))
88+
assert isinstance(dim, (int, type(None))), f"{type(dim)=} {dim=}"
89+
return dim
90+
91+
def check_delta_symbol2example_value(self, symbol2example_value: dict):
92+
if len(symbol2example_value) == 0:
8993
return True
9094

91-
sym2example_value = self.merge_symbol2example_value(sym2example_value)
95+
symbol2example_value = self._merge_symbol2example_value(symbol2example_value)
9296

9397
sym_exprs = [
9498
*self._get_sym_exprs_from_input_shapes(),
@@ -98,22 +102,7 @@ def check_delta_symbol2example_value(self, sym2example_value: dict):
98102
relations = [*self.relations, *(sym_expr > 0 for sym_expr in sym_exprs)]
99103

100104
return all(
101-
relation.subs(sym2example_value) == sympy.true for relation in relations
102-
)
103-
104-
def _get_sym_exprs_from_input_shapes(self):
105-
yield from (
106-
sym_dim
107-
for name, shape in self.input_shapes
108-
for sym_dim in shape
109-
if isinstance(sym_dim, sympy.Expr)
110-
)
111-
112-
def _get_sym_exprs_from_input_max_values(self):
113-
yield from (
114-
sym_value
115-
for name, sym_value in self.input_max_values
116-
if isinstance(sym_value, sympy.Expr)
105+
relation.subs(symbol2example_value) == sympy.true for relation in relations
117106
)
118107

119108
def serialize_to_py_str(self):
@@ -188,6 +177,47 @@ def load_module(cls, path, name="unamed"):
188177
spec.loader.exec_module(module)
189178
return module
190179

180+
def _new_symbol(self, example_value):
181+
max_existed_seq_no = max(
182+
[
183+
-1,
184+
*(
185+
seq_no
186+
for symbol in self.symbols
187+
for seq_no in [int(symbol.name[len(self.kSymbolVarNamePrefix) :])]
188+
),
189+
]
190+
)
191+
seq_no = max_existed_seq_no + 1
192+
symbol = sympy.Symbol(f"{self.kSymbolVarNamePrefix}{seq_no}")
193+
self.symbol2example_value[symbol] = example_value
194+
self.symbols.append(symbol)
195+
return symbol
196+
197+
def _merge_symbol2example_value(self, symbol2example_value: dict):
198+
return {
199+
k: v
200+
for k, v in [
201+
*self.symbol2example_value.items(),
202+
*symbol2example_value.items(),
203+
]
204+
}
205+
206+
def _get_sym_exprs_from_input_shapes(self):
207+
yield from (
208+
sym_dim
209+
for shape, name in self.input_shapes
210+
for sym_dim in shape
211+
if isinstance(sym_dim, sympy.Expr)
212+
)
213+
214+
def _get_sym_exprs_from_input_max_values(self):
215+
yield from (
216+
sym_value
217+
for sym_value, name in self.input_max_values
218+
if isinstance(sym_value, sympy.Expr)
219+
)
220+
191221

192222
if __name__ == "__main__":
193223
cstr_code = """

0 commit comments

Comments
 (0)