Skip to content

Commit 1a42094

Browse files
authored
add robustness code for generating input tensor constraints (#379)
* support checking model redundancy * revert change of vision_model_test * reformat python code. * reformat bert_model_test.py and utils.py * minor fix * fix failed check by comparing directories after os.path.realpath() * fix bugs in check_validate.sh * set dynamic=False in single_device_runner.py * reset graph hash * add robustness code for generating input tensor constraints
1 parent 32d68ab commit 1a42094

File tree

5 files changed

+35
-6
lines changed

5 files changed

+35
-6
lines changed
Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,2 @@
1-
samples/timm/resnetaa50d.d_in12k
2-
samples/timm/regnetx_016.pycls_in1k
3-
samples/timm/repghostnet_130.in1k
1+
samples/transformers-auto-model/opus-mt-en-gmw
2+
samples/transformers-auto-model/Michielo_mt5-small_nl-en_translation

graph_net/constraint_util.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from graph_net.imp_util import load_module
33
from graph_net.tensor_meta import TensorMeta
44
from typing import Callable
5+
import functools
56
import copy
67
import sys
78
import os
@@ -36,6 +37,7 @@ def _make_config(
3637
model_runnable_predicator_class_name="ModelRunner",
3738
model_runnable_predicator_config=None,
3839
model_path_prefix="",
40+
resume=False,
3941
):
4042
if data_input_predicator_config is None:
4143
data_input_predicator_config = {}
@@ -49,10 +51,23 @@ def _make_config(
4951
"model_runnable_predicator_class_name": model_runnable_predicator_class_name,
5052
"model_runnable_predicator_config": model_runnable_predicator_config,
5153
"model_path_prefix": model_path_prefix,
54+
"resume": resume,
5255
}
5356

5457
def __call__(self, model_path):
5558
model_path = os.path.join(self.config["model_path_prefix"], model_path)
59+
print(f"{model_path=}")
60+
cstr_path = os.path.join(model_path, "input_tensor_constraints.py")
61+
if (
62+
self.config["resume"]
63+
and os.path.exists(cstr_path)
64+
and DynamicDimConstraints.kSymbols in open(cstr_path).read()
65+
):
66+
module = load_module(cstr_path)
67+
symbols = getattr(module, DynamicDimConstraints.kSymbols)
68+
if len(symbols) > 0:
69+
return
70+
5671
tensor_metas = self._get_tensor_metas(model_path)
5772
dyn_dim_cstr = make_dyn_dim_cstr_from_tensor_metas(tensor_metas)
5873

@@ -111,6 +126,11 @@ def update_tensor_metas_by_dyn_dim_cstr(
111126
assert len(tensor_metas) == len(input_shapes)
112127
for i, tensor_meta in enumerate(tensor_metas):
113128
tensor_meta.shape = input_shapes[i]
129+
if tensor_meta.data is not None:
130+
assert isinstance(tensor_meta.data, (list, tuple))
131+
size = functools.reduce(lambda a, b: a * b, tensor_meta.shape, 1)
132+
doubled_data = [*tensor_meta.data, *tensor_meta.data]
133+
tensor_meta.data = doubled_data[:size]
114134

115135

116136
def make_dyn_dim_cstr_from_tensor_metas(tensor_metas: list[TensorMeta]):
@@ -152,7 +172,11 @@ def dumpy_filter_fn(input_name, input_idx, axis, dim):
152172
cur_dyn_dim_cstr = copy.deepcopy(dyn_dim_cstr)
153173

154174
def filter_fn(input_name, input_idx, axis, dim):
155-
return is_data_input(input_name) and dim == picked_dim
175+
return (
176+
is_data_input(input_name)
177+
and dim == picked_dim
178+
and (dim > 1 or axis == 0)
179+
)
156180

157181
symbol = cur_dyn_dim_cstr.symbolize(filter_fn)
158182
if symbol is None:

graph_net/model_path_handler.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,12 @@ def main(args):
3939
handler = _get_handler(args)
4040
for model_path in _get_model_paths(args):
4141
print(f"{model_path=}")
42-
handler(model_path)
42+
try:
43+
handler(model_path)
44+
except KeyboardInterrupt:
45+
sys.exit(-1)
46+
except Exception as e:
47+
pass
4348

4449

4550
def _get_model_paths(args):

graph_net/test/batch_init_input_tensor_constraints_test.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ config_json_str=$(cat <<EOF
1111
"handler_path": "$GRAPH_NET_ROOT/constraint_util.py",
1212
"handler_class_name": "UpdateInputTensorConstraints",
1313
"handler_config": {
14+
"resume": true,
1415
"model_path_prefix": "$GRAPH_NET_ROOT/../",
1516
"data_input_predicator_filepath": "$GRAPH_NET_ROOT/torch/constraint_util.py",
1617
"data_input_predicator_class_name": "NaiveDataInputPredicator",

graph_net/torch/constraint_util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ def __init__(self, config):
77
self.config = config
88

99
def __call__(self, model_path, input_var_name: str) -> bool:
10-
return not ("self_" in input_var_name and "_modules_" in input_var_name)
10+
return not ("_self_" in input_var_name)
1111

1212

1313
class ModelRunnablePredicator:

0 commit comments

Comments
 (0)