Skip to content

Commit 32d68ab

Browse files
authored
Provide utilities to generate input meta constraints (#378)
* support checking model redundancy * revert change of vision_model_test * reformat python code. * reformat bert_model_test.py and utils.py * minor fix * fix failed check by comparing directories after os.path.realpath() * fix bugs in check_validate.sh * set dynamic=False in single_device_runner.py * reset graph hash * backup code * generate initial input_tensor_constraints.py * reorder symbol names * remove unused input_max_values * batch initial input_tensor_constraints.py
1 parent 8abea5d commit 32d68ab

File tree

12 files changed

+4458
-0
lines changed

12 files changed

+4458
-0
lines changed
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
samples/timm/resnetaa50d.d_in12k
2+
samples/timm/regnetx_016.pycls_in1k
3+
samples/timm/repghostnet_130.in1k

graph_net/constraint_util.py

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
from graph_net.dynamic_dim_constraints import DynamicDimConstraints
2+
from graph_net.imp_util import load_module
3+
from graph_net.tensor_meta import TensorMeta
4+
from typing import Callable
5+
import copy
6+
import sys
7+
import os
8+
9+
10+
class UpdateInputTensorConstraints:
11+
def __init__(self, config=None):
12+
if config is None:
13+
config = {}
14+
self.config = self._make_config(**config)
15+
self.data_input_predicator = self._make_data_input_predicator(self.config)
16+
self.model_runnable_predicator = self._make_model_runnable_predicator(
17+
self.config
18+
)
19+
20+
def _make_data_input_predicator(self, config):
21+
module = load_module(config["data_input_predicator_filepath"])
22+
cls = getattr(module, config["data_input_predicator_class_name"])
23+
return cls(config["data_input_predicator_config"])
24+
25+
def _make_model_runnable_predicator(self, config):
26+
module = load_module(config["model_runnable_predicator_filepath"])
27+
cls = getattr(module, config["model_runnable_predicator_class_name"])
28+
return cls(config["model_runnable_predicator_config"])
29+
30+
def _make_config(
31+
self,
32+
data_input_predicator_filepath,
33+
model_runnable_predicator_filepath,
34+
data_input_predicator_class_name="DataInputPredicator",
35+
data_input_predicator_config=None,
36+
model_runnable_predicator_class_name="ModelRunner",
37+
model_runnable_predicator_config=None,
38+
model_path_prefix="",
39+
):
40+
if data_input_predicator_config is None:
41+
data_input_predicator_config = {}
42+
if model_runnable_predicator_config is None:
43+
model_runnable_predicator_config = {}
44+
return {
45+
"data_input_predicator_filepath": data_input_predicator_filepath,
46+
"data_input_predicator_class_name": data_input_predicator_class_name,
47+
"data_input_predicator_config": data_input_predicator_config,
48+
"model_runnable_predicator_filepath": model_runnable_predicator_filepath,
49+
"model_runnable_predicator_class_name": model_runnable_predicator_class_name,
50+
"model_runnable_predicator_config": model_runnable_predicator_config,
51+
"model_path_prefix": model_path_prefix,
52+
}
53+
54+
def __call__(self, model_path):
55+
model_path = os.path.join(self.config["model_path_prefix"], model_path)
56+
tensor_metas = self._get_tensor_metas(model_path)
57+
dyn_dim_cstr = make_dyn_dim_cstr_from_tensor_metas(tensor_metas)
58+
59+
def data_input_predicator(input_var_name):
60+
return self.data_input_predicator(model_path, input_var_name)
61+
62+
def is_dyn_dim_cstr_feasible(dyn_dim_cstr):
63+
return self._is_dyn_dim_cstr_feasible(
64+
model_path, tensor_metas, dyn_dim_cstr
65+
)
66+
67+
dyn_dim_cstr = symbolize_data_input_dims(
68+
dyn_dim_cstr,
69+
is_data_input=data_input_predicator,
70+
is_dyn_dim_cstr_feasible=is_dyn_dim_cstr_feasible,
71+
)
72+
self._save_dyn_dim_cstr(dyn_dim_cstr, model_path)
73+
74+
def _save_dyn_dim_cstr(self, dyn_dim_cstr, model_path):
75+
cstr_code = dyn_dim_cstr.serialize_to_py_str()
76+
with open(os.path.join(model_path, "input_tensor_constraints.py"), "w") as fp:
77+
fp.write(cstr_code)
78+
79+
def _get_tensor_metas(self, model_path):
80+
make = TensorMeta.unserialize_from_py_file
81+
return [
82+
*make(os.path.join(model_path, "input_meta.py")),
83+
*make(os.path.join(model_path, "weight_meta.py")),
84+
]
85+
86+
def _is_dyn_dim_cstr_feasible(
87+
self, model_path, tensor_metas, dyn_dim_cstr: DynamicDimConstraints
88+
):
89+
tensor_metas = copy.deepcopy(tensor_metas)
90+
update_tensor_metas_by_dyn_dim_cstr(tensor_metas, dyn_dim_cstr)
91+
weight_meta_code = "\n".join(
92+
tensor_meta.serialize_to_py_str() for tensor_meta in tensor_metas
93+
)
94+
import tempfile
95+
96+
with tempfile.TemporaryDirectory() as tmpdir:
97+
for filename in ["graph_net.json", "model.py"]:
98+
with open(os.path.join(tmpdir, filename), "w") as f:
99+
f.write(open(os.path.join(model_path, filename)).read())
100+
with open(os.path.join(tmpdir, "input_meta.py"), "w") as f:
101+
f.write("")
102+
with open(os.path.join(tmpdir, "weight_meta.py"), "w") as f:
103+
f.write(weight_meta_code)
104+
return self.model_runnable_predicator(tmpdir)
105+
106+
107+
def update_tensor_metas_by_dyn_dim_cstr(
108+
tensor_metas: list[TensorMeta], dyn_dim_cstr: DynamicDimConstraints
109+
):
110+
input_shapes = dyn_dim_cstr.get_reified_input_shapes()
111+
assert len(tensor_metas) == len(input_shapes)
112+
for i, tensor_meta in enumerate(tensor_metas):
113+
tensor_meta.shape = input_shapes[i]
114+
115+
116+
def make_dyn_dim_cstr_from_tensor_metas(tensor_metas: list[TensorMeta]):
117+
named_shapes = [
118+
(shape, name)
119+
for tensor_meta in tensor_metas
120+
for name in [tensor_meta.name]
121+
for shape in [tensor_meta.shape]
122+
]
123+
return DynamicDimConstraints.make_by_named_inputs(
124+
named_shapes=named_shapes,
125+
)
126+
127+
128+
def symbolize_data_input_dims(
129+
dyn_dim_cstr: DynamicDimConstraints,
130+
is_data_input: Callable[[str], bool],
131+
is_dyn_dim_cstr_feasible: Callable[[DynamicDimConstraints], bool],
132+
) -> DynamicDimConstraints | None:
133+
"""
134+
is_data_input: Callable[["input_var_name:str"], bool]
135+
Symbolizes data input dimensions as much as possible.
136+
Returns new DynamicDimConstraints if success.
137+
Returns None if no symbolicable dim .
138+
"""
139+
unqiue_dims = []
140+
141+
def dumpy_filter_fn(input_name, input_idx, axis, dim):
142+
if is_data_input(input_name):
143+
print("data_input", input_name, input_idx, axis, dim)
144+
if dim not in unqiue_dims:
145+
unqiue_dims.append(dim)
146+
# No symbolization because of returning True
147+
return False
148+
149+
# Collect input dimensions into `unqiue_dims`
150+
assert dyn_dim_cstr.symbolize(dumpy_filter_fn) is None
151+
for picked_dim in unqiue_dims:
152+
cur_dyn_dim_cstr = copy.deepcopy(dyn_dim_cstr)
153+
154+
def filter_fn(input_name, input_idx, axis, dim):
155+
return is_data_input(input_name) and dim == picked_dim
156+
157+
symbol = cur_dyn_dim_cstr.symbolize(filter_fn)
158+
if symbol is None:
159+
continue
160+
sym2example_value = {symbol: picked_dim + 1}
161+
if not cur_dyn_dim_cstr.check_delta_symbol2example_value(sym2example_value):
162+
continue
163+
tmp_dyn_dim_cstr = copy.deepcopy(cur_dyn_dim_cstr)
164+
tmp_dyn_dim_cstr.update_symbol2example_value(sym2example_value)
165+
if not is_dyn_dim_cstr_feasible(tmp_dyn_dim_cstr):
166+
continue
167+
dyn_dim_cstr = cur_dyn_dim_cstr
168+
return dyn_dim_cstr

0 commit comments

Comments
 (0)