
Commit d4ad054

Implement NaiveDecomposer for paddle.

1 parent d9c4002

3 files changed (+212 −2 lines)

graph_net/paddle/extractor.py

Lines changed: 77 additions & 2 deletions
@@ -11,6 +11,34 @@
 from graph_net.paddle import utils


+# used as configuration of python -m graph_net.paddle.run_model
+class RunModelDecorator:
+    def __init__(self, config):
+        self.config = self.make_config(**config)
+
+    def __call__(self, model):
+        return extract(**self.config)(model)
+
+    def make_config(
+        self,
+        name=None,
+        dynamic=True,
+        input_spec=None,
+        custom_extractor_path: str = None,
+        custom_extractor_config: dict = None,
+    ):
+        assert name is not None
+        return {
+            "name": name,
+            "dynamic": dynamic,
+            "input_spec": input_spec,
+            "extractor_config": {
+                "custom_extractor_path": custom_extractor_path,
+                "custom_extractor_config": custom_extractor_config,
+            },
+        }
+
+
 class GraphExtractor:
     def __init__(
         self,
@@ -125,10 +153,42 @@ def __call__(self, **input_dict):
         return static_model


-def extract(name, dynamic=False, input_spec=None):
+def extract(name, dynamic=False, input_spec=None, extractor_config: dict = None):
+    """
+    Extract computation graphs from a PaddlePaddle nn.Layer.
+    The extracted computation graphs are saved into the directory given by the env var $GRAPH_NET_EXTRACT_WORKSPACE.
+
+    Args:
+        name (str): The name of the model, used as the directory name for saving.
+        dynamic (bool): Enable dynamic shape support in paddle.jit.to_static.
+        input_spec (list[InputSpec] | tuple[InputSpec]): InputSpec for the input tensors, including each tensor's name, shape and dtype.
+            When dynamic is False, input_spec can be inferred automatically.
+
+    Returns:
+        wrapper or decorator
+    """
+
+    extractor_config = make_extractor_config(extractor_config)
+
+    def get_graph_extractor_maker():
+        custom_extractor_path = extractor_config["custom_extractor_path"]
+        custom_extractor_config = extractor_config["custom_extractor_config"]
+        if custom_extractor_path is None:
+            return GraphExtractor
+        import importlib.util as imp
+
+        print(f"Import graph_extractor from {custom_extractor_path}")
+        # import custom_extractor_path as graph_extractor
+        spec = imp.spec_from_file_location("graph_extractor", custom_extractor_path)
+        graph_extractor = imp.module_from_spec(spec)
+        spec.loader.exec_module(graph_extractor)
+        cls = graph_extractor.GraphExtractor
+        return lambda *args, **kwargs: cls(custom_extractor_config, *args, **kwargs)
+
     def wrapper(model: paddle.nn.Layer):
         assert isinstance(model, paddle.nn.Layer), f"{type(model)=}"
-        return GraphExtractor(model, name, dynamic, input_spec)
+        extractor = get_graph_extractor_maker()(model, name, dynamic, input_spec)
+        return extractor

     def decorator(module_class):
         def constructor(*args, **kwargs):
@@ -147,3 +207,18 @@ def decorator_or_wrapper(obj):
         )

     return decorator_or_wrapper
+
+
+def make_extractor_config(extractor_config):
+    kwargs = extractor_config if extractor_config is not None else {}
+    return make_extractor_config_impl(**kwargs)
+
+
+def make_extractor_config_impl(
+    custom_extractor_path: str = None, custom_extractor_config: dict = None
+):
+    config = custom_extractor_config if custom_extractor_config is not None else {}
+    return {
+        "custom_extractor_path": custom_extractor_path,
+        "custom_extractor_config": config,
+    }
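Taken together, these changes let extract() delegate to a custom extractor class loaded from a file: the class named GraphExtractor in that file is constructed with custom_extractor_config as its first argument, followed by the usual (model, name, dynamic, input_spec). A minimal usage sketch follows; the toy model, the file path, and the config values are illustrative assumptions, only the argument names and config keys come from this commit.

# Illustrative sketch only: the toy model, path, and config values are hypothetical.
import paddle
from graph_net.paddle.extractor import extract

class ToyModel(paddle.nn.Layer):
    def forward(self, x):
        return paddle.nn.functional.relu(x)

# Default behaviour (unchanged by this commit): the built-in GraphExtractor wraps the model.
wrapped = extract(name="toy_model", dynamic=False)(ToyModel())

# With extractor_config: GraphExtractor is loaded from the given file and receives
# custom_extractor_config as its first constructor argument.
wrapped_custom = extract(
    name="toy_model",
    dynamic=False,
    extractor_config={
        "custom_extractor_path": "/path/to/custom_extractor.py",  # hypothetical path
        "custom_extractor_config": {"split_positions": [2]},  # keys from the decomposer file below
    },
)(ToyModel())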
Lines changed: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
+class GraphExtractor:
+    def __init__(
+        self,
+        config: dict,
+        model,
+        name,
+        dynamic,
+        input_spec=None,
+    ):
+        self.subgraph_counter = 0
+        self.model = model
+        self.name = name
+        self.dynamic = dynamic
+        self.input_spec = input_spec
+        self.config = self.make_config(**config)
+
+    def make_config(
+        self,
+        split_positions=(),
+        group_head_and_tail=False,
+        chain_style=False,
+        output_dir="./tmp/naive_decomposer_dir",
+        filter_path=None,
+        filter_config=None,
+    ):
+        for pos in split_positions:
+            assert isinstance(
+                pos, int
+            ), f"split_positions should be list of int, {split_positions=}"
+        return {
+            "split_positions": split_positions,
+            "group_head_and_tail": group_head_and_tail,
+            "chain_style": chain_style,
+            "output_dir": output_dir,
+            "filter_path": filter_path,
+            "filter_config": filter_config if filter_config is not None else {},
+        }
+
+    def __call__(self, **input_dict):
+        config = {
+            k: v
+            for k, v in self.config.items()
+            if k in {"split_positions", "group_head_and_tail", "chain_style"}
+        }
+        static_model = paddle.jit.to_static(
+            self.model, input_spec=self.input_spec, full_graph=True
+        )
+        static_model(**input_dict)
+        return static_model
+
+    def get_naive_decomposer_extractor(self, submodule, seq_no):
+        return NaiveDecomposerExtractor(self, submodule, seq_no)
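This new extractor is the decomposer-style GraphExtractor referenced by the commit message; it normalizes its configuration through make_config before use. A sketch of a config dict it would accept is below; the key names and defaults come from this file, the concrete values are examples only.

# Illustrative config for the decomposer GraphExtractor; values are examples only.
decomposer_config = {
    "split_positions": [3, 7],  # must all be ints (checked by make_config)
    "group_head_and_tail": True,
    "chain_style": False,
    "output_dir": "./tmp/naive_decomposer_dir",  # default from make_config
    "filter_path": None,  # optional, default None
    "filter_config": None,  # normalized to {} when None
}
# Constructed the same way extract() builds custom extractors:
# extractor = GraphExtractor(decomposer_config, model, "toy_model", dynamic=False)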

graph_net/paddle/run_model.py

Lines changed: 83 additions & 0 deletions
@@ -0,0 +1,83 @@
+import sys
+import json
+import base64
+import argparse
+import importlib.util
+from typing import Type
+
+import paddle
+from graph_net.paddle import utils
+
+
+def load_class_from_file(file_path: str, class_name: str):
+    print(f"Load {class_name} from {file_path}")
+    spec = importlib.util.spec_from_file_location("unnamed", file_path)
+    unnamed = importlib.util.module_from_spec(spec)
+    spec.loader.exec_module(unnamed)
+    model_class = getattr(unnamed, class_name, None)
+    return model_class
+
+
+def get_input_dict(model_path):
+    inputs_params = utils.load_converted_from_text(f"{model_path}")
+    params = inputs_params["weight_info"]
+    inputs = inputs_params["input_info"]
+
+    params.update(inputs)
+    state_dict = {k: utils.replay_tensor(v) for k, v in params.items()}
+    return state_dict
+
+
+def _convert_to_dict(config_str):
+    if config_str is None:
+        return {}
+    config_str = base64.b64decode(config_str).decode("utf-8")
+    config = json.loads(config_str)
+    assert isinstance(config, dict), f"config should be a dict. {config_str=}"
+    return config
+
+
+def _get_decorator(args):
+    if args.decorator_config is None:
+        return lambda model: model
+    decorator_config = _convert_to_dict(args.decorator_config)
+    if "decorator_path" not in decorator_config:
+        return lambda model: model
+    decorator_class = load_class_from_file(
+        decorator_config["decorator_path"], class_name="RunModelDecorator"
+    )
+    return decorator_class(decorator_config.get("decorator_config", {}))
+
+
+def main(args):
+    model_path = args.model_path
+    model_class = load_class_from_file(
+        f"{model_path}/model.py", class_name="GraphModule"
+    )
+    assert model_class is not None
+    model = model_class()
+    print(f"{model_path=}")
+
+    model = _get_decorator(args)(model)
+    input_dict = get_input_dict(args.model_path)
+
+    model(**input_dict)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="load and run model")
+    parser.add_argument(
+        "--model-path",
+        type=str,
+        required=True,
+        help="Path to folder e.g '../../paddle_samples/PaddleX/ResNet18'",
+    )
+    parser.add_argument(
+        "--decorator-config",
+        type=str,
+        required=False,
+        default=None,
+        help="decorator configuration string",
+    )
+    args = parser.parse_args()
+    main(args=args)
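run_model.py reads --decorator-config as a base64-encoded JSON dict; when it contains a decorator_path, RunModelDecorator is loaded from that file and applied to the model with the nested decorator_config. Below is a sketch of how such a string could be produced; the model path and config values are illustrative, while decorator_path points at extractor.py, where this commit defines RunModelDecorator.

# Illustrative: build the base64-encoded JSON consumed by --decorator-config.
# The model path and the config values below are examples, not part of the commit.
import base64
import json

decorator_config = {
    "decorator_path": "graph_net/paddle/extractor.py",  # file defining RunModelDecorator
    "decorator_config": {"name": "ResNet18", "dynamic": False},
}
encoded = base64.b64encode(json.dumps(decorator_config).encode("utf-8")).decode("utf-8")

print(
    "python -m graph_net.paddle.run_model "
    "--model-path ../../paddle_samples/PaddleX/ResNet18 "
    f"--decorator-config {encoded}"
)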
