Skip to content

Commit 687d973

Browse files
authored
Gen input tensor constraints (#380)
* support checking model redundancy * revert change of vision_model_test * reformat python code. * reformat bert_model_test.py and utils.py * minor fix * fix failed check by comparing directories after os.path.realpath() * fix bugs in check_validate.sh * set dynamic=False in single_device_runner.py * reset graph hash * add robustness code for generating input tensor constraints * Introduce input_tensor_constraints.py using shape propagation logic.
1 parent 1a42094 commit 687d973

File tree

6 files changed

+83
-1
lines changed

6 files changed

+83
-1
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
1+
samples/timm/resnetaa50d.d_in12k
12
samples/transformers-auto-model/opus-mt-en-gmw
23
samples/transformers-auto-model/Michielo_mt5-small_nl-en_translation

graph_net/model_path_handler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def _get_model_paths(args):
5858
for line in f
5959
for clean_line in [line.strip()]
6060
if len(clean_line) > 0
61+
if not clean_line.startswith("#")
6162
)
6263

6364

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
#!/bin/bash
2+
3+
GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(
4+
os.path.dirname(graph_net.__file__))")
5+
6+
# input model path
7+
MODEL_NAME=resnet18
8+
MODEL_PATH_IN_SAMPLES=/timm/$MODEL_NAME
9+
config_json_str=$(cat <<EOF
10+
{
11+
"handler_path": "$GRAPH_NET_ROOT/constraint_util.py",
12+
"handler_class_name": "UpdateInputTensorConstraints",
13+
"handler_config": {
14+
"resume": false,
15+
"model_path_prefix": "$GRAPH_NET_ROOT/../",
16+
"data_input_predicator_filepath": "$GRAPH_NET_ROOT/torch/constraint_util.py",
17+
"data_input_predicator_class_name": "NaiveDataInputPredicator",
18+
"model_runnable_predicator_filepath": "$GRAPH_NET_ROOT/torch/constraint_util.py",
19+
"model_runnable_predicator_class_name": "RunModelPredicator",
20+
"model_runnable_predicator_config": {
21+
"decorator_path": "$GRAPH_NET_ROOT/torch/shape_prop.py",
22+
"decorator_class_name": "ShapePropagate"
23+
}
24+
}
25+
}
26+
EOF
27+
)
28+
CONFIG=$(echo $config_json_str | base64 -w 0)
29+
30+
python3 -m graph_net.model_path_handler --model-path-list $GRAPH_NET_ROOT/config/small_torch_samples_list.txt --handler-config=$CONFIG

graph_net/torch/constraint_util.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,21 @@ def __init__(self, config):
1717
def __call__(self, model_path):
1818
cmd = f"{sys.executable} -m graph_net.torch.run_model --model-path {model_path}"
1919
return os.system(cmd) == 0
20+
21+
22+
class RunModelPredicator:
23+
def __init__(self, config=None):
24+
if config is None:
25+
config = {}
26+
self.config = config
27+
28+
def __call__(self, model_path):
29+
import json
30+
import base64
31+
32+
json_string = json.dumps(self.config)
33+
json_bytes = json_string.encode("utf-8")
34+
b64_encoded_bytes = base64.b64encode(json_bytes)
35+
decorator_config = b64_encoded_bytes.decode("utf-8")
36+
cmd = f"{sys.executable} -m graph_net.torch.run_model --model-path {model_path} --decorator-config {decorator_config}"
37+
return os.system(cmd) == 0

graph_net/torch/run_model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,10 @@ def _get_decorator(args):
3535
decorator_config = _convert_to_dict(args.decorator_config)
3636
if "decorator_path" not in decorator_config:
3737
return lambda model: model
38+
class_name = decorator_config.get("decorator_class_name", "RunModelDecorator")
3839
decorator_class = load_class_from_file(
39-
decorator_config["decorator_path"], class_name="RunModelDecorator"
40+
decorator_config["decorator_path"],
41+
class_name=class_name,
4042
)
4143
return decorator_class(decorator_config.get("decorator_config", {}))
4244

graph_net/torch/shape_prop.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import torch
2+
from typing import Union, Callable
3+
from torch.fx.passes.shape_prop import ShapeProp
4+
import inspect
5+
6+
7+
# used as configuration of python3 -m graph_net.torch.run_model
8+
class ShapePropagate:
9+
def __init__(self, config=None):
10+
if config is None:
11+
config = {}
12+
self.config = config
13+
14+
def __call__(self, module):
15+
return ShapePropModule(self.config, module)
16+
17+
18+
class ShapePropModule(torch.nn.Module):
19+
def __init__(self, config, module):
20+
super().__init__()
21+
self.config = config
22+
self.module = module
23+
24+
def forward(self, *args, **kwargs):
25+
assert len(args) == 0
26+
traced_model = torch.fx.symbolic_trace(self.module)
27+
inputs = [
28+
kwargs[name] for name in inspect.signature(self.module.forward).parameters
29+
]
30+
propagated_model = ShapeProp(traced_model).propagate(*inputs)

0 commit comments

Comments
 (0)