Skip to content

Commit ca20508

Browse files
committed
add a script to check fusability of a given model
1 parent adff744 commit ca20508

18 files changed

+94
-4517
lines changed

graph_net/test/dimension_generalization_test.sh

100644100755
File mode changed.
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
#!/bin/bash
2+
3+
GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(
4+
os.path.dirname(graph_net.__file__))")
5+
6+
# input model path
7+
MODEL_NAME=resnet18d.ra2_in1k
8+
MODEL_PATH_IN_SAMPLES=/timm/$MODEL_NAME
9+
checker_config_json_str=$(cat <<EOF
10+
{
11+
"post_extract_process_config": {
12+
"post_extract_process_path":"$GRAPH_NET_ROOT/torch/post_extract_process_count_kernels.py",
13+
"post_extract_process_class_name": "GraphFullyFusionable"
14+
}
15+
}
16+
EOF
17+
)
18+
CHECKER_CONFIG=$(echo $checker_config_json_str | base64 -w 0)
19+
20+
python3 -m graph_net.torch.check_model_fusability --model-path $GRAPH_NET_ROOT/../samples/$MODEL_PATH_IN_SAMPLES --checker-config=$CHECKER_CONFIG

graph_net/test/naive_decomposer_and_post_extract_process_test.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ decorator_config_json_str=$(cat <<EOF
2020
"filter_path":"$GRAPH_NET_ROOT/torch/naive_subgraph_filter.py",
2121
"filter_config": {},
2222
"post_extract_process_path":"$GRAPH_NET_ROOT/torch/post_extract_process_count_kernels.py",
23-
"post_extract_process_class_name": "PostExtractProcess"
23+
"post_extract_process_class_name": "GraphFullyFusionable"
2424
}
2525
}
2626
}
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
from . import utils
2+
import argparse
3+
import importlib.util
4+
import inspect
5+
from graph_net.imp_util import load_module
6+
import torch
7+
import logging
8+
from pathlib import Path
9+
from typing import Type, Any
10+
import sys
11+
import json
12+
import base64
13+
from contextlib import contextmanager
14+
15+
16+
def _load_class_from_file(file_path, class_name):
17+
module = load_module(file_path)
18+
return getattr(module, class_name)
19+
20+
21+
def _convert_to_dict(config_str):
22+
if config_str is None:
23+
return {}
24+
config_str = base64.b64decode(config_str).decode("utf-8")
25+
config = json.loads(config_str)
26+
assert isinstance(config, dict), f"config should be a dict. {config_str=}"
27+
return config
28+
29+
30+
def _get_checker(args):
31+
if args.checker_config is None:
32+
return lambda model_path: model_path
33+
checker_config = _convert_to_dict(args.checker_config).get(
34+
"post_extract_process_config"
35+
)
36+
checker_class = _load_class_from_file(
37+
checker_config["post_extract_process_path"],
38+
class_name=checker_config["post_extract_process_class_name"],
39+
)
40+
return checker_class(checker_config.get("checker_config", {}))
41+
42+
43+
def main(args):
44+
checker = _get_checker(args)
45+
model_path = args.model_path
46+
print(f"{model_path=}")
47+
try:
48+
checker(model_path)
49+
except KeyboardInterrupt:
50+
sys.exit(-1)
51+
except Exception as e:
52+
print(e)
53+
54+
55+
if __name__ == "__main__":
56+
parser = argparse.ArgumentParser(description="load and run model")
57+
parser.add_argument(
58+
"--model-path",
59+
type=str,
60+
required=True,
61+
help="Path to folder e.g '../../samples/torch/resnet18'",
62+
)
63+
parser.add_argument(
64+
"--checker-config",
65+
type=str,
66+
required=False,
67+
default=None,
68+
help="checker configuration string",
69+
)
70+
args = parser.parse_args()
71+
main(args=args)

graph_net/torch/naive_graph_decomposer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,4 +117,4 @@ def make_post_extract_process(self, config):
117117
if config["post_extract_process_path"] is None:
118118
return None
119119
module = imp_util.load_module(config["post_extract_process_path"])
120-
return module.PostExtractProcess(config["post_extract_process_path"])
120+
return module.GraphFullyFusionable(config["post_extract_process_path"])

graph_net/torch/post_extract_process_count_kernels.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from torch.profiler import profile, record_function, ProfilerActivity
77

88

9-
class PostExtractProcess:
9+
class GraphFullyFusionable:
1010
def __init__(self, config):
1111
self.config = config
1212

samples/timm/resnet18/graph_hash.txt

Lines changed: 0 additions & 1 deletion
This file was deleted.

samples/timm/resnet18/graph_net.json

Lines changed: 0 additions & 7 deletions
This file was deleted.

samples/timm/resnet18/input_meta.py

Whitespace-only changes.

samples/timm/resnet18/input_tensor_constraints.py

Lines changed: 0 additions & 210 deletions
This file was deleted.

0 commit comments

Comments
 (0)