Skip to content

Commit 54c2597

Browse files
authored
Gen input tensor constraints (#382)
* Support checking model redundancy.
* Revert change of vision_model_test.
* Reformat Python code.
* Reformat bert_model_test.py and utils.py.
* Minor fix.
* Fix failed check by comparing directories after os.path.realpath().
* Fix bugs in check_validate.sh.
* Set dynamic=False in single_device_runner.py.
* Reset graph hash.
* Add robustness code for generating input tensor constraints.
* Introduce input_tensor_constraints.py using shape propagation logic.
* Support dimension generalization for torch.Tensor.view and torch.Tensor.reshape.
1 parent 687d973 commit 54c2597

12 files changed

+463
-29
lines changed
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
samples/timm/resnetaa50d.d_in12k
2-
samples/transformers-auto-model/opus-mt-en-gmw
3-
samples/transformers-auto-model/Michielo_mt5-small_nl-en_translation
2+
# samples/transformers-auto-model/opus-mt-en-gmw
3+
# samples/transformers-auto-model/Michielo_mt5-small_nl-en_translation

graph_net/constraint_util.py

Lines changed: 58 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@
66
import copy
77
import sys
88
import os
9+
from contextlib import contextmanager
10+
import tempfile
11+
import shutil
12+
from pathlib import Path
913

1014

1115
class UpdateInputTensorConstraints:
@@ -33,25 +37,33 @@ def _make_config(
3337
data_input_predicator_filepath,
3438
model_runnable_predicator_filepath,
3539
data_input_predicator_class_name="DataInputPredicator",
36-
data_input_predicator_config=None,
3740
model_runnable_predicator_class_name="ModelRunner",
41+
data_input_predicator_config=None,
3842
model_runnable_predicator_config=None,
43+
dimension_generalizer_filepath=None,
44+
dimension_generalizer_class_name="StaticToDynamic",
45+
dimension_generalizer_config=None,
3946
model_path_prefix="",
4047
resume=False,
4148
):
4249
if data_input_predicator_config is None:
4350
data_input_predicator_config = {}
4451
if model_runnable_predicator_config is None:
4552
model_runnable_predicator_config = {}
53+
if dimension_generalizer_config is None:
54+
dimension_generalizer_config = {}
4655
return {
56+
"resume": resume,
57+
"model_path_prefix": model_path_prefix,
4758
"data_input_predicator_filepath": data_input_predicator_filepath,
4859
"data_input_predicator_class_name": data_input_predicator_class_name,
4960
"data_input_predicator_config": data_input_predicator_config,
5061
"model_runnable_predicator_filepath": model_runnable_predicator_filepath,
5162
"model_runnable_predicator_class_name": model_runnable_predicator_class_name,
5263
"model_runnable_predicator_config": model_runnable_predicator_config,
53-
"model_path_prefix": model_path_prefix,
54-
"resume": resume,
64+
"dimension_generalizer_filepath": dimension_generalizer_filepath,
65+
"dimension_generalizer_class_name": dimension_generalizer_class_name,
66+
"dimension_generalizer_config": dimension_generalizer_config,
5567
}
5668

5769
def __call__(self, model_path):
@@ -74,17 +86,51 @@ def __call__(self, model_path):
7486
def data_input_predicator(input_var_name):
7587
return self.data_input_predicator(model_path, input_var_name)
7688

77-
def is_dyn_dim_cstr_feasible(dyn_dim_cstr):
78-
return self._is_dyn_dim_cstr_feasible(
79-
model_path, tensor_metas, dyn_dim_cstr
80-
)
89+
with self._try_dimension_generalization(
90+
model_path, tensor_metas
91+
) as tmp_model_path:
92+
93+
def is_dyn_dim_cstr_feasible(dyn_dim_cstr):
94+
return self._is_dyn_dim_cstr_feasible(
95+
tmp_model_path, tensor_metas, dyn_dim_cstr
96+
)
8197

82-
dyn_dim_cstr = symbolize_data_input_dims(
83-
dyn_dim_cstr,
84-
is_data_input=data_input_predicator,
85-
is_dyn_dim_cstr_feasible=is_dyn_dim_cstr_feasible,
98+
dyn_dim_cstr = symbolize_data_input_dims(
99+
dyn_dim_cstr,
100+
is_data_input=data_input_predicator,
101+
is_dyn_dim_cstr_feasible=is_dyn_dim_cstr_feasible,
102+
)
103+
self._save_dyn_dim_cstr(dyn_dim_cstr, model_path)
104+
105+
@contextmanager
def _try_dimension_generalization(self, model_path, tensor_metas):
    """Yield the model directory to use for constraint feasibility checks.

    When ``self.config["dimension_generalizer_filepath"]`` is ``None``, or
    when the configured pass reports that no rewrite is needed,
    ``model_path`` itself is yielded unchanged.

    Otherwise the model directory is copied into a temporary directory, the
    rewritten graph module is saved there through the pass object, and the
    temporary directory path is yielded. The temporary directory is removed
    when the ``with`` block of the caller exits; the original ``model_path``
    is never modified by this method.

    Args:
        model_path: Directory containing the model's ``model.py``.
        tensor_metas: Iterable of dataclass instances; each one is converted
            with ``dataclasses.asdict`` and handed to the pass object.

    Yields:
        Either ``model_path`` or the path of a temporary copy holding the
        rewritten model.
    """
    generalizer_filepath = self.config["dimension_generalizer_filepath"]
    if generalizer_filepath is None:
        # No generalizer configured: operate on the original directory.
        yield model_path
        return
    py_module = load_module(os.path.join(model_path, "model.py"))
    GraphModule = getattr(py_module, "GraphModule")
    # Propagate the source file path from the loaded module onto the class.
    GraphModule.__graph_net_file_path__ = py_module.__graph_net_file_path__
    model = GraphModule()
    decorator_cls = getattr(
        load_module(generalizer_filepath),
        self.config["dimension_generalizer_class_name"],
    )
    pass_obj = decorator_cls(self.config["dimension_generalizer_config"])(model)
    if not pass_obj.need_rewrite():
        yield model_path
        return
    from dataclasses import asdict

    tensor_meta_attrs_list = [asdict(tensor_meta) for tensor_meta in tensor_metas]
    graph_module = pass_obj.rewrite_with_tensor_meta_attrs_list(
        tensor_meta_attrs_list
    )
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Start from a full copy so files other than the rewritten graph
        # module remain available next to it.
        shutil.copytree(Path(model_path), Path(tmp_dir), dirs_exist_ok=True)
        pass_obj.save_graph_module(graph_module, tmp_dir)
        yield tmp_dir
88134

89135
def _save_dyn_dim_cstr(self, dyn_dim_cstr, model_path):
90136
cstr_code = dyn_dim_cstr.serialize_to_py_str()
@@ -106,7 +152,6 @@ def _is_dyn_dim_cstr_feasible(
106152
weight_meta_code = "\n".join(
107153
tensor_meta.serialize_to_py_str() for tensor_meta in tensor_metas
108154
)
109-
import tempfile
110155

111156
with tempfile.TemporaryDirectory() as tmpdir:
112157
for filename in ["graph_net.json", "model.py"]:

graph_net/imp_util.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,5 @@ def load_module(path, name="unamed"):
55
spec = imp.spec_from_file_location(name, path)
66
module = imp.module_from_spec(spec)
77
spec.loader.exec_module(module)
8+
module.__graph_net_file_path__ = path
89
return module

graph_net/model_path_handler.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import traceback
12
import argparse
23
import importlib.util
34
from graph_net.imp_util import load_module
@@ -44,7 +45,11 @@ def main(args):
4445
except KeyboardInterrupt:
4546
sys.exit(-1)
4647
except Exception as e:
47-
pass
48+
print("--- Concise Error Message ---")
49+
print(e)
50+
51+
print("\n--- Full Traceback ---")
52+
traceback.print_exc()
4853

4954

5055
def _get_model_paths(args):

graph_net/test/batch_init_input_tensor_constraints_test.sh

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,14 @@ config_json_str=$(cat <<EOF
1111
"handler_path": "$GRAPH_NET_ROOT/constraint_util.py",
1212
"handler_class_name": "UpdateInputTensorConstraints",
1313
"handler_config": {
14-
"resume": true,
14+
"resume": false,
1515
"model_path_prefix": "$GRAPH_NET_ROOT/../",
1616
"data_input_predicator_filepath": "$GRAPH_NET_ROOT/torch/constraint_util.py",
1717
"data_input_predicator_class_name": "NaiveDataInputPredicator",
1818
"model_runnable_predicator_filepath": "$GRAPH_NET_ROOT/torch/constraint_util.py",
19-
"model_runnable_predicator_class_name": "ModelRunnablePredicator"
19+
"model_runnable_predicator_class_name": "ModelRunnablePredicator",
20+
"dimension_generalizer_filepath": "$GRAPH_NET_ROOT/torch/static_to_dynamic.py",
21+
"dimension_generalizer_class_name": "StaticToDynamic"
2022
}
2123
}
2224
EOF
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
#!/bin/bash

# Resolve the directory of the installed graph_net package.
GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(os.path.dirname(graph_net.__file__))")

# Decorator configuration (JSON) selecting the StaticToDynamic decorator.
config_json_str=$(cat <<EOF
{
    "decorator_path": "$GRAPH_NET_ROOT/torch/static_to_dynamic.py",
    "decorator_class_name": "StaticToDynamic",
    "decorator_config": {
        "output_dir": ""
    }
}
EOF
)
# Base64-encode the JSON on a single line so it can be passed as one argument.
# Quoting keeps the JSON text intact (no word splitting or glob expansion).
CONFIG=$(echo "$config_json_str" | base64 -w 0)

# Quoted so that a package path containing spaces is passed as one argument.
python3 -m graph_net.torch.run_model --model-path "$GRAPH_NET_ROOT/../samples/transformers-auto-model/opus-mt-en-gmw" --decorator-config="$CONFIG"

graph_net/test/shape_prop_batch_init_input_tensor_constraints_test.sh

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,9 @@ config_json_str=$(cat <<EOF
1616
"data_input_predicator_filepath": "$GRAPH_NET_ROOT/torch/constraint_util.py",
1717
"data_input_predicator_class_name": "NaiveDataInputPredicator",
1818
"model_runnable_predicator_filepath": "$GRAPH_NET_ROOT/torch/constraint_util.py",
19-
"model_runnable_predicator_class_name": "RunModelPredicator",
20-
"model_runnable_predicator_config": {
21-
"decorator_path": "$GRAPH_NET_ROOT/torch/shape_prop.py",
22-
"decorator_class_name": "ShapePropagate"
23-
}
19+
"model_runnable_predicator_class_name": "ShapePropagatablePredicator",
20+
"dimension_generalizer_filepath": "$GRAPH_NET_ROOT/torch/static_to_dynamic.py",
21+
"dimension_generalizer_class_name": "StaticToDynamic"
2422
}
2523
}
2624
EOF

graph_net/torch/constraint_util.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,23 @@ def __call__(self, model_path):
1919
return os.system(cmd) == 0
2020

2121

22+
class ShapePropagatablePredicator:
    """Predicator that delegates to ``RunModelPredicator`` with ``ShapePropagate``.

    The wrapped predicator is configured with the ``ShapePropagate`` class
    from ``torch/shape_prop.py`` under the installed ``graph_net`` package.
    """

    def __init__(self, config=None):
        # NOTE(review): ``config`` is accepted but never consulted; the
        # decorator settings below are fixed. Confirm this is intended.
        import graph_net

        package_dir = os.path.dirname(graph_net.__file__)
        self.predicator = RunModelPredicator(
            {
                "decorator_path": f"{package_dir}/torch/shape_prop.py",
                "decorator_class_name": "ShapePropagate",
            }
        )

    def __call__(self, model_path):
        """Return whatever the wrapped predicator returns for ``model_path``."""
        return self.predicator(model_path)
37+
38+
2239
class RunModelPredicator:
2340
def __init__(self, config=None):
2441
if config is None:
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import torch
2+
3+
4+
def parse_sole_graph_module(module, inputs):
5+
traced_module = None
6+
traced_sample_inputs = None
7+
8+
def my_backend(gm, sample_inputs):
9+
nonlocal traced_module
10+
traced_module = gm
11+
nonlocal traced_sample_inputs
12+
traced_sample_inputs = sample_inputs
13+
return gm.forward
14+
15+
torch.compile(module, backend=my_backend)(*inputs)
16+
assert traced_module is not None
17+
assert all(id(a) == id(b) for a, b in zip(inputs, traced_sample_inputs))
18+
for node in traced_module.graph.nodes:
19+
if node.op != "placeholder":
20+
continue
21+
assert node.target[:2] == "L_" or node.target[:2] == "l_", f"{node.target=}"
22+
node.target = node.target[2:]
23+
assert node.name[:2] == "L_" or node.name[:2] == "l_", f"{node.name=}"
24+
node.name = node.name[2:]
25+
return traced_module

graph_net/torch/run_model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ def load_class_from_file(file_path: str, class_name: str) -> Type[torch.nn.Modul
1717
unnamed = importlib.util.module_from_spec(spec)
1818
spec.loader.exec_module(unnamed)
1919
model_class = getattr(unnamed, class_name, None)
20+
setattr(model_class, "__graph_net_file_path__", file_path)
2021
return model_class
2122

2223

0 commit comments

Comments
 (0)