Skip to content

Commit a6b3c15

Browse files
committed
add run_model.py to improve reusability.
1 parent 5e309b0 commit a6b3c15

File tree

3 files changed

+124
-13
lines changed

3 files changed

+124
-13
lines changed

graph_net/test/naive_graph_decomposer_test.sh

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,21 +4,26 @@ GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(
44
os.path.dirname(graph_net.__file__))")
55

66
# input model path
7-
MODEL_PATH_IN_SAMPLES=/timm/resnet18
8-
extractor_config_json_str=$(cat <<EOF
7+
MODEL_NAME=resnet18
8+
MODEL_PATH_IN_SAMPLES=/timm/$MODEL_NAME
9+
decorator_config_json_str=$(cat <<EOF
910
{
10-
"custom_extractor_path": "$GRAPH_NET_ROOT/torch/naive_graph_decomposer.py",
11-
"custom_extractor_config": {
12-
"output_dir": "/tmp/naive_decompose_workspace",
13-
"split_positions": [8, 16, 32],
14-
"group_head_and_tail": true,
15-
"filter_path":"$GRAPH_NET_ROOT/torch/naive_subgraph_filter.py",
16-
"filter_config": {}
11+
"decorator_path": "$GRAPH_NET_ROOT/torch/extractor.py",
12+
"decorator_config": {
13+
"name": "$MODEL_NAME",
14+
"custom_extractor_path": "$GRAPH_NET_ROOT/torch/naive_graph_decomposer.py",
15+
"custom_extractor_config": {
16+
"output_dir": "/tmp/naive_decompose_workspace",
17+
"split_positions": [8, 16, 32],
18+
"group_head_and_tail": true,
19+
"filter_path":"$GRAPH_NET_ROOT/torch/naive_subgraph_filter.py",
20+
"filter_config": {}
21+
}
1722
}
1823
}
1924
EOF
2025
)
21-
EXTRACTOR_CONFIG=$(echo $extractor_config_json_str | base64 -w 0)
26+
DECORATOR_CONFIG=$(echo $decorator_config_json_str | base64 -w 0)
2227

2328
mkdir -p /tmp/naive_decompose_workspace
24-
python3 -m graph_net.torch.single_device_runner --model-path $GRAPH_NET_ROOT/../samples/$MODEL_PATH_IN_SAMPLES --enable-extract True --extract-name resnet18 --dump-graph-hash-key --extractor-config=$EXTRACTOR_CONFIG
29+
python3 -m graph_net.torch.run_model --model-path $GRAPH_NET_ROOT/../samples/$MODEL_PATH_IN_SAMPLES --decorator-config=$DECORATOR_CONFIG

graph_net/torch/extractor.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
import json
44
import shutil
55
from typing import Union, Callable
6-
from . import utils
7-
from .fx_graph_serialize_util import serialize_graph_module_to_str
6+
from graph_net.torch import utils
7+
from graph_net.torch.fx_graph_serialize_util import serialize_graph_module_to_str
88

99
torch._dynamo.config.capture_scalar_outputs = True
1010
torch._dynamo.config.capture_dynamic_output_shape_ops = True
@@ -13,6 +13,34 @@
1313
torch._dynamo.config.allow_rnn = True
1414

1515

16+
# used as configuration of python3 -m graph_net.torch.run_model
17+
class RunModelDecorator:
18+
def __init__(self, config):
19+
self.config = self.make_config(**config)
20+
21+
def __call__(self, model):
22+
return extract(**self.config)(model)
23+
24+
def make_config(
25+
self,
26+
name=None,
27+
dynamic=True,
28+
placeholder_auto_rename=False,
29+
custom_extractor_path: str = None,
30+
custom_extractor_config: dict = None,
31+
):
32+
assert name is not None
33+
return {
34+
"name": name,
35+
"dynamic": dynamic,
36+
"placeholder_auto_rename": placeholder_auto_rename,
37+
"extractor_config": {
38+
"custom_extractor_path": custom_extractor_path,
39+
"custom_extractor_config": custom_extractor_config,
40+
},
41+
}
42+
43+
1644
class GraphExtractor:
1745
def __init__(
1846
self,

graph_net/torch/run_model.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
from . import utils
2+
import argparse
3+
import importlib.util
4+
import inspect
5+
import torch
6+
import logging
7+
from pathlib import Path
8+
from typing import Type, Any
9+
import sys
10+
import json
11+
import base64
12+
from contextlib import contextmanager
13+
14+
15+
def load_class_from_file(file_path: str, class_name: str) -> Type[torch.nn.Module]:
16+
spec = importlib.util.spec_from_file_location("unnamed", file_path)
17+
unnamed = importlib.util.module_from_spec(spec)
18+
spec.loader.exec_module(unnamed)
19+
model_class = getattr(unnamed, class_name, None)
20+
return model_class
21+
22+
23+
def _convert_to_dict(config_str):
24+
if config_str is None:
25+
return {}
26+
config_str = base64.b64decode(config_str).decode("utf-8")
27+
config = json.loads(config_str)
28+
assert isinstance(config, dict), f"config should be a dict. {config_str=}"
29+
return config
30+
31+
32+
def _get_decorator(args):
33+
if args.decorator_config is None:
34+
return lambda model: model
35+
decorator_config = _convert_to_dict(args.decorator_config)
36+
if "decorator_path" not in decorator_config:
37+
return lambda model: model
38+
decorator_class = load_class_from_file(
39+
decorator_config["decorator_path"], class_name="RunModelDecorator"
40+
)
41+
return decorator_class(decorator_config.get("decorator_config", {}))
42+
43+
44+
def main(args):
45+
model_path = args.model_path
46+
model_class = load_class_from_file(
47+
f"{model_path}/model.py", class_name="GraphModule"
48+
)
49+
assert model_class is not None
50+
model = model_class()
51+
print(f"{model_path=}")
52+
53+
model = _get_decorator(args)(model)
54+
55+
inputs_params = utils.load_converted_from_text(f"{model_path}")
56+
params = inputs_params["weight_info"]
57+
state_dict = {k: utils.replay_tensor(v) for k, v in params.items()}
58+
59+
model(**state_dict)
60+
61+
62+
if __name__ == "__main__":
63+
parser = argparse.ArgumentParser(description="load and run model")
64+
parser.add_argument(
65+
"--model-path",
66+
type=str,
67+
required=True,
68+
help="Path to folder e.g '../../samples/torch/resnet18'",
69+
)
70+
parser.add_argument(
71+
"--decorator-config",
72+
type=str,
73+
required=False,
74+
default=None,
75+
help="decorator configuration string",
76+
)
77+
args = parser.parse_args()
78+
main(args=args)

0 commit comments

Comments
 (0)