Skip to content

Commit de2d8c4

Browse files
authored
Submodule filter (#349)
1 parent c2280ca commit de2d8c4

File tree

7 files changed

+97
-45
lines changed

7 files changed

+97
-45
lines changed

graph_net/imp_util.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
import importlib.util as imp
2+
3+
4+
def load_module(path, name="unamed"):
5+
spec = imp.spec_from_file_location(name, path)
6+
module = imp.module_from_spec(spec)
7+
spec.loader.exec_module(module)
8+
return module
Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,22 @@
11
#!/bin/bash
2-
set -x
2+
3+
GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(
4+
os.path.dirname(graph_net.__file__))")
35

46
# input model path
57
MODEL_PATH_IN_SAMPLES=/timm/resnet18
6-
# extract subgraph 0-8, 8-16
7-
read -r -d '' json_str <<'EOF'
8+
read -r -d '' extractor_config_json_str <<EOF
89
{
9-
"output_dir": "/tmp/naive_decompose_workspace",
10-
"split_positions": [8, 16, 32],
11-
"group_head_and_tail": true,
12-
"chain_style": true
10+
"custom_extractor_path": "$GRAPH_NET_ROOT/torch/naive_graph_decomposer.py",
11+
"custom_extractor_config": {
12+
"output_dir": "/tmp/naive_decompose_workspace",
13+
"split_positions": [8, 16, 32],
14+
"group_head_and_tail": true,
15+
"chain_style": true
16+
}
1317
}
1418
EOF
15-
CONFIG=$(echo $json_str | base64 -w 0)
19+
EXTRACTOR_CONFIG=$(echo $extractor_config_json_str | base64 -w 0)
1620

1721
mkdir -p /tmp/naive_decompose_workspace
18-
GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(
19-
os.path.dirname(graph_net.__file__))")
20-
python3 -m graph_net.torch.single_device_runner --model-path $GRAPH_NET_ROOT/../samples/$MODEL_PATH_IN_SAMPLES --enable-extract True --extract-name resnet18 --dump-graph-hash-key --custom-extractor-path=$GRAPH_NET_ROOT/torch/naive_graph_decomposer.py --custom-extractor-config=$CONFIG
22+
python3 -m graph_net.torch.single_device_runner --model-path $GRAPH_NET_ROOT/../samples/$MODEL_PATH_IN_SAMPLES --enable-extract True --extract-name resnet18 --dump-graph-hash-key --extractor-config=$EXTRACTOR_CONFIG
Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,23 @@
11
#!/bin/bash
22

3+
GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(
4+
os.path.dirname(graph_net.__file__))")
5+
36
# input model path
47
MODEL_PATH_IN_SAMPLES=/timm/resnet18
5-
read -r -d '' json_str <<'EOF'
8+
read -r -d '' extractor_config_json_str <<EOF
69
{
7-
"output_dir": "/tmp/naive_decompose_workspace",
8-
"split_positions": [8, 32],
9-
"group_head_and_tail": true
10+
"custom_extractor_path": "$GRAPH_NET_ROOT/torch/naive_graph_decomposer.py",
11+
"custom_extractor_config": {
12+
"output_dir": "/tmp/naive_decompose_workspace",
13+
"split_positions": [8, 32],
14+
"group_head_and_tail": true,
15+
"filter_path":"$GRAPH_NET_ROOT/torch/naive_subgraph_filter.py",
16+
"filter_config": {}
17+
}
1018
}
1119
EOF
12-
CONFIG=$(echo $json_str | base64 -w 0)
20+
EXTRACTOR_CONFIG=$(echo $extractor_config_json_str | base64 -w 0)
1321

1422
mkdir -p /tmp/naive_decompose_workspace
15-
GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(
16-
os.path.dirname(graph_net.__file__))")
17-
python3 -m graph_net.torch.single_device_runner --model-path $GRAPH_NET_ROOT/../samples/$MODEL_PATH_IN_SAMPLES --enable-extract True --extract-name resnet18 --dump-graph-hash-key --custom-extractor-path=$GRAPH_NET_ROOT/torch/naive_graph_decomposer.py --custom-extractor-config=$CONFIG
23+
python3 -m graph_net.torch.single_device_runner --model-path $GRAPH_NET_ROOT/../samples/$MODEL_PATH_IN_SAMPLES --enable-extract True --extract-name resnet18 --dump-graph-hash-key --extractor-config=$EXTRACTOR_CONFIG

graph_net/torch/extractor.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,8 +139,7 @@ def extract(
139139
dynamic=True,
140140
mut_graph_codes=None,
141141
placeholder_auto_rename=False,
142-
custom_extractor_path: str = None,
143-
custom_extractor_config: str = None,
142+
extractor_config: dict = None,
144143
):
145144
"""
146145
Extract computation graphs from PyTorch nn.Module.
@@ -210,7 +209,11 @@ def forward(self, s0 : torch.SymInt, L_x_ : torch.Tensor):
210209
>>>
211210
"""
212211

212+
extractor_config = make_extractor_config(extractor_config)
213+
213214
def get_graph_extractor_maker():
215+
custom_extractor_path = extractor_config["custom_extractor_path"]
216+
custom_extractor_config = extractor_config["custom_extractor_config"]
214217
if custom_extractor_path is None:
215218
return GraphExtractor
216219
import importlib.util as imp
@@ -247,3 +250,18 @@ def decorator_or_wrapper(obj):
247250
)
248251

249252
return decorator_or_wrapper
253+
254+
255+
def make_extractor_config(extractor_config):
256+
kwargs = extractor_config if extractor_config is not None else {}
257+
return make_extractor_config_impl(**kwargs)
258+
259+
260+
def make_extractor_config_impl(
261+
custom_extractor_path: str = None, custom_extractor_config: dict = None
262+
):
263+
config = custom_extractor_config if custom_extractor_config is not None else {}
264+
return {
265+
"custom_extractor_path": custom_extractor_path,
266+
"custom_extractor_config": config,
267+
}

graph_net/torch/naive_graph_decomposer.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,17 @@
11
import os
22
import torch
3-
import json
4-
import base64
53
import shutil
64
from typing import Union, Callable
75
from graph_net.torch import utils
86
from graph_net.torch.decompose_util import convert_to_submodules_graph
97
from graph_net.torch.extractor import GraphExtractor as BuiltinGraphExtractor
8+
import graph_net.imp_util as imp_util
109

1110

1211
class GraphExtractor:
1312
def __init__(
1413
self,
15-
config_str: str,
14+
config: dict,
1615
name,
1716
dynamic,
1817
mut_graph_codes=None,
@@ -23,14 +22,16 @@ def __init__(
2322
self.dynamic = dynamic
2423
self.mut_graph_codes = mut_graph_codes
2524
self.placeholder_auto_rename = placeholder_auto_rename
26-
self.config = self.make_config(**self.convert_to_dict(config_str))
25+
self.config = self.make_config(**config)
2726

2827
def make_config(
2928
self,
3029
split_positions=(),
3130
group_head_and_tail=False,
3231
chain_style=False,
3332
output_dir="./tmp/naive_decomposer_dir",
33+
filter_path=None,
34+
filter_config=None,
3435
):
3536
for pos in split_positions:
3637
assert isinstance(
@@ -41,6 +42,8 @@ def make_config(
4142
"group_head_and_tail": group_head_and_tail,
4243
"chain_style": chain_style,
4344
"output_dir": output_dir,
45+
"filter_path": filter_path,
46+
"filter_config": filter_config if filter_config is not None else {},
4447
}
4548

4649
def __call__(self, gm: torch.fx.GraphModule, sample_inputs):
@@ -59,14 +62,6 @@ def __call__(self, gm: torch.fx.GraphModule, sample_inputs):
5962
def get_naive_decomposer_extractor(self, submodule, seq_no):
6063
return NaiveDecomposerExtractor(self, submodule, seq_no)
6164

62-
def convert_to_dict(self, config_str):
63-
if config_str is None:
64-
return {}
65-
config_str = base64.b64decode(config_str).decode("utf-8")
66-
config = json.loads(config_str)
67-
assert isinstance(config, dict), f"config should be a dict. {config_str=}"
68-
return config
69-
7065

7166
class NaiveDecomposerExtractor(torch.nn.Module):
7267
def __init__(self, parent_graph_extractor, submodule, seq_no):
@@ -83,9 +78,22 @@ def __init__(self, parent_graph_extractor, submodule, seq_no):
8378
placeholder_auto_rename=parent_graph_extractor.placeholder_auto_rename,
8479
workspace_path=self.parent_graph_extractor.config["output_dir"],
8580
)
81+
self.filter = self.make_filter(self.parent_graph_extractor.config)
8682

8783
def forward(self, *args):
8884
if not self.extracted:
89-
self.builtin_extractor(self.submodule, args)
85+
if self.need_extract(self.submodule, args):
86+
self.builtin_extractor(self.submodule, args)
9087
self.extracted = True
9188
return self.submodule(*args)
89+
90+
def need_extract(self, gm, sample_inputs):
91+
if self.filter is None:
92+
return True
93+
return self.filter(gm, sample_inputs)
94+
95+
def make_filter(self, config):
96+
if config["filter_path"] is None:
97+
return None
98+
module = imp_util.load_module(config["filter_path"])
99+
return module.GraphFilter(config["filter_config"])
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
class GraphFilter:
2+
def __init__(self, config):
3+
self.config = config
4+
5+
def __call__(self, gm, sample_inputs):
6+
print(f"GraphFilter\n{gm.code}")
7+
return True

graph_net/torch/single_device_runner.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
import sys
1010
from graph_net.torch.extractor import extract
1111
import hashlib
12+
import json
13+
import base64
1214
from contextlib import contextmanager
1315

1416

@@ -20,6 +22,15 @@ def load_class_from_file(file_path: str, class_name: str) -> Type[torch.nn.Modul
2022
return model_class
2123

2224

25+
def convert_to_dict(config_str):
26+
if config_str is None:
27+
return {}
28+
config_str = base64.b64decode(config_str).decode("utf-8")
29+
config = json.loads(config_str)
30+
assert isinstance(config, dict), f"config should be a dict. {config_str=}"
31+
return config
32+
33+
2334
def _get_sha_hash(content):
2435
m = hashlib.sha256()
2536
m.update(content.encode())
@@ -63,8 +74,7 @@ def main(args):
6374
kwargs = dict(
6475
name=args.extract_name,
6576
dynamic=False,
66-
custom_extractor_path=args.custom_extractor_path,
67-
custom_extractor_config=args.custom_extractor_config,
77+
extractor_config=convert_to_dict(args.extractor_config),
6878
**dump_graph_options,
6979
)
7080
model = extract(**kwargs)(model)
@@ -118,18 +128,11 @@ def main(args):
118128
help="Extracted graph's name",
119129
)
120130
parser.add_argument(
121-
"--custom-extractor-path",
122-
type=str,
123-
required=False,
124-
default=None,
125-
help="Custom extractor python file path",
126-
)
127-
parser.add_argument(
128-
"--custom-extractor-config",
131+
"--extractor-config",
129132
type=str,
130133
required=False,
131134
default=None,
132-
help="Custom extractor configuration string",
135+
help="extractor configuration string",
133136
)
134137
args = parser.parse_args()
135138
main(args=args)

0 commit comments

Comments
 (0)