Skip to content

Commit 9546015

Browse files
committed
Implement NaiveDecomposer for paddle.
1 parent d9c4002 commit 9546015

File tree

4 files changed

+280
-30
lines changed

4 files changed

+280
-30
lines changed

graph_net/paddle/extractor.py

Lines changed: 144 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,34 @@
1111
from graph_net.paddle import utils
1212

1313

14+
# used as configuration of python -m graph_net.paddle.run_model
15+
class RunModelDecorator:
    """Decorator object configured for ``python -m graph_net.paddle.run_model``.

    The raw config mapping is normalised once in ``__init__``; calling the
    instance forwards the normalised config to ``extract`` and applies the
    result to the given model.
    """

    def __init__(self, config):
        # Unknown keys in ``config`` fail here with a TypeError from make_config.
        self.config = self.make_config(**config)

    def __call__(self, model):
        """Apply ``extract(**self.config)`` to ``model`` and return the result."""
        apply_extract = extract(**self.config)
        return apply_extract(model)

    def make_config(
        self,
        name=None,
        dynamic=False,
        input_spec=None,
        custom_extractor_path: str = None,
        custom_extractor_config: dict = None,
    ):
        """Build the keyword arguments later passed to ``extract``.

        ``name`` is mandatory; the two ``custom_extractor_*`` values are
        nested under the ``"extractor_config"`` key.
        """
        assert name is not None
        nested_extractor_config = {
            "custom_extractor_path": custom_extractor_path,
            "custom_extractor_config": custom_extractor_config,
        }
        extract_kwargs = {
            "name": name,
            "dynamic": dynamic,
            "input_spec": input_spec,
            "extractor_config": nested_extractor_config,
        }
        return extract_kwargs
40+
41+
1442
class GraphExtractor:
1543
def __init__(
1644
self,
@@ -26,7 +54,10 @@ def __init__(
2654
self.input_spec = input_spec
2755
assert not self.dynamic, "dynamic=True is not supported now!"
2856

29-
self.subgraph_counter = 0
57+
self.num_subgraphs = 0
58+
self.num_samples_of_all_subgraphs = 0
59+
self.subgraph_idx2samples = None
60+
3061
self.dump_path = os.environ.get("GRAPH_NET_PIR_DUMP_WORKSPACE", "/tmp")
3162
self.workspace_path = (
3263
workspace_path
@@ -57,30 +88,23 @@ def prepare_to_extract(self, model_dump_path):
5788
)
5889
return old_flags
5990

60-
def write_to_file(self, filepath, content):
61-
print(f"Write to {filepath}")
62-
with open(filepath, "w") as f:
63-
f.write(content)
64-
65-
def __call__(self, **input_dict):
66-
# 1. Get model dump path
67-
model_dump_path = os.path.join(self.dump_path, self.name)
68-
old_flags = self.prepare_to_extract(model_dump_path)
69-
91+
def run_model(self, **input_dict):
    """Convert ``self.model`` with ``paddle.jit.to_static`` and run it once.

    When ``self.input_spec`` is None it is first filled in from the
    ``paddle.Tensor`` values found in ``input_dict`` (shape, dtype and the
    keyword name); non-tensor values are ignored for that purpose.

    Returns:
        The static model produced by ``paddle.jit.to_static``, after it has
        been called with ``input_dict``.
    """
    if self.input_spec is None:
        inferred_specs = []
        for arg_name, arg_value in input_dict.items():
            if not isinstance(arg_value, paddle.Tensor):
                continue
            inferred_specs.append(
                paddle.static.InputSpec(
                    arg_value.shape, arg_value.dtype, name=arg_name
                )
            )
        self.input_spec = inferred_specs

    converted_model = paddle.jit.to_static(
        self.model, input_spec=self.input_spec, full_graph=True
    )
    # The return value of the call itself is discarded; only the converted
    # model is handed back to the caller.
    converted_model(**input_dict)
    return converted_model
82104

83-
# 3. Convert pir programs to graphnet samples
105+
def translate_pir_program_to_sample_codes(
106+
self, model_dump_path, split_positions=None
107+
):
84108
ir_programs_path = os.path.join(model_dump_path, "exec_programs.py")
85109
example_inputs_path = os.path.join(
86110
model_dump_path, "programs_example_input_tensor_meta.py"
@@ -92,29 +116,73 @@ def __call__(self, **input_dict):
92116
example_inputs_path
93117
), f"{example_inputs_path} is not a regular file."
94118

119+
# Arguments for graph decomposer
120+
op_example_inputs_path = (
121+
os.path.join(model_dump_path, "op_example_input_tensor_meta.py")
122+
if split_positions
123+
else None
124+
)
125+
split_positions = (
126+
",".join(map(str, split_positions))
127+
if split_positions and isinstance(split_positions, (tuple, list))
128+
else split_positions
129+
)
130+
95131
graphnet_samples = generate_samples(
96132
model_name=self.name,
97133
ir_programs=ir_programs_path,
98134
example_inputs=example_inputs_path,
135+
op_example_inputs=op_example_inputs_path,
136+
split_positions=split_positions,
99137
eval_mode=True,
100138
)
101139

140+
self.subgraph_idx2samples = {}
141+
for sample in graphnet_samples:
142+
if sample.subgraph_idx not in self.subgraph_idx2samples.keys():
143+
self.subgraph_idx2samples[sample.subgraph_idx] = []
144+
self.subgraph_idx2samples[sample.subgraph_idx].append(sample)
145+
146+
self.num_subgraphs = len(self.subgraph_idx2samples)
147+
self.num_samples_of_all_subgraphs = len(graphnet_samples)
148+
return self.subgraph_idx2samples
149+
150+
def write_sample_to_file(self, subgraph_path, sample):
    """Write one sample's sources and metadata into ``subgraph_path``.

    The directory is created when it does not exist. ``sample.model``,
    ``sample.weight_meta`` and ``sample.input_meta`` are written to
    ``model.py``, ``weight_meta.py`` and ``input_meta.py`` respectively, and
    ``sample.metadata`` is dumped as JSON to ``graph_net.json``.
    """
    if not os.path.exists(subgraph_path):
        os.makedirs(subgraph_path, exist_ok=True)

    # Each text attribute of the sample maps to a file of the same stem.
    for attr_name in ("model", "weight_meta", "input_meta"):
        target_file = f"{subgraph_path}/{attr_name}.py"
        text = getattr(sample, attr_name)
        print(f"Write to {target_file}")
        with open(target_file, "w") as out:
            out.write(text)

    metadata_file = os.path.join(subgraph_path, "graph_net.json")
    with open(metadata_file, "w") as out:
        json.dump(sample.metadata, out, indent=4)
163+
164+
def __call__(self, **input_dict):
165+
# 1. Get model dump path
166+
model_dump_path = os.path.join(self.dump_path, self.name)
167+
old_flags = self.prepare_to_extract(model_dump_path)
168+
169+
# 2. Run the model to dump pir programs
170+
static_model = self.run_model(**input_dict)
171+
172+
# 3. Convert pir programs to graphnet samples
173+
self.translate_pir_program_to_sample_codes(
174+
model_dump_path, split_positions=None
175+
)
176+
102177
# 4. Save to model_path
103178
model_path = os.path.join(self.workspace_path, self.name)
104-
self.subgraph_counter = len(graphnet_samples)
105-
for i, sample in enumerate(graphnet_samples):
106-
subgraph_path = (
107-
model_path
108-
if self.subgraph_counter == 1
109-
else os.path.join(model_path, f"subgraph_{i}")
110-
)
111-
if not os.path.exists(subgraph_path):
112-
os.makedirs(subgraph_path, exist_ok=True)
113-
self.write_to_file(f"{subgraph_path}/model.py", sample.model)
114-
self.write_to_file(f"{subgraph_path}/weight_meta.py", sample.weight_meta)
115-
self.write_to_file(f"{subgraph_path}/input_meta.py", sample.input_meta)
116-
with open(os.path.join(subgraph_path, "graph_net.json"), "w") as f:
117-
json.dump(sample.metadata, f, indent=4)
179+
for subgraph_idx, samples in self.subgraph_idx2samples.items():
180+
assert len(samples) == 1
181+
if self.num_samples_of_all_subgraphs == 1:
182+
subgraph_path = model_path
183+
else:
184+
subgraph_path = os.path.join(model_path, f"subgraph_{subgraph_idx}")
185+
self.write_sample_to_file(subgraph_path, samples[0])
118186

119187
print(
120188
f"Graph and tensors for '{self.name}' extracted successfully to: {model_path}"
@@ -125,10 +193,42 @@ def __call__(self, **input_dict):
125193
return static_model
126194

127195

128-
def extract(name, dynamic=False, input_spec=None):
196+
def extract(name, dynamic=False, input_spec=None, extractor_config: dict = None):
197+
"""
198+
Extract computation graphs from PaddlePaddle nn.Layer.
199+
The extracted computation graphs will be saved into directory of env var $GRAPH_NET_EXTRACT_WORKSPACE.
200+
201+
Args:
202+
name (str): The name of the model, used as the directory name for saving.
203+
dynamic (bool): Enable dynamic shape support in paddle.jit.to_static.
204+
input_spec (list[InputSpec] | tuple[InputSpec]): InputSpec for input tensors, which includes tensor's name, shape and dtype.
205+
When dynamic is False, input_spec can be inferred automatically.
206+
207+
Returns:
208+
wrapper or decorator
209+
"""
210+
211+
extractor_config = make_extractor_config(extractor_config)
212+
213+
def get_graph_extractor_maker():
214+
custom_extractor_path = extractor_config["custom_extractor_path"]
215+
custom_extractor_config = extractor_config["custom_extractor_config"]
216+
if custom_extractor_path is None:
217+
return GraphExtractor
218+
import importlib.util as imp
219+
220+
print(f"Import graph_extractor from {custom_extractor_path}")
221+
# import custom_extractor_path as graph_extractor
222+
spec = imp.spec_from_file_location("graph_extractor", custom_extractor_path)
223+
graph_extractor = imp.module_from_spec(spec)
224+
spec.loader.exec_module(graph_extractor)
225+
cls = graph_extractor.GraphExtractor
226+
return lambda *args, **kwargs: cls(custom_extractor_config, *args, **kwargs)
227+
129228
def wrapper(model: paddle.nn.Layer):
130229
assert isinstance(model, paddle.nn.Layer), f"{type(model)=}"
131-
return GraphExtractor(model, name, dynamic, input_spec)
230+
extractor = get_graph_extractor_maker()(model, name, dynamic, input_spec)
231+
return extractor
132232

133233
def decorator(module_class):
134234
def constructor(*args, **kwargs):
@@ -147,3 +247,18 @@ def decorator_or_wrapper(obj):
147247
)
148248

149249
return decorator_or_wrapper
250+
251+
252+
def make_extractor_config(extractor_config):
    """Normalise an optional extractor config mapping.

    ``None`` is treated like an empty mapping; otherwise the mapping's items
    are passed as keyword arguments to ``make_extractor_config_impl``.
    """
    if extractor_config is None:
        return make_extractor_config_impl()
    return make_extractor_config_impl(**extractor_config)
255+
256+
257+
def make_extractor_config_impl(
    custom_extractor_path: str = None, custom_extractor_config: dict = None
):
    """Return the extractor config dict with both keys always present.

    A ``custom_extractor_config`` of None is replaced by a new empty dict;
    ``custom_extractor_path`` is passed through unchanged (None included).
    """
    if custom_extractor_config is None:
        custom_extractor_config = {}
    normalized = {}
    normalized["custom_extractor_path"] = custom_extractor_path
    normalized["custom_extractor_config"] = custom_extractor_config
    return normalized
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
class GraphExtractor:
    """Custom graph extractor carrying naive-decomposer settings.

    The constructor takes the custom extractor config as its first argument,
    followed by the model, its name, the ``dynamic`` flag and an optional
    ``input_spec``. Calling the instance converts the model with
    ``paddle.jit.to_static`` and runs it once.
    """

    def __init__(
        self,
        config: dict,
        model,
        name,
        dynamic,
        input_spec=None,
    ):
        self.subgraph_counter = 0
        self.model = model
        self.name = name
        self.dynamic = dynamic
        self.input_spec = input_spec
        # Unknown keys in ``config`` fail here with a TypeError from make_config.
        self.config = self.make_config(**config)

    def make_config(
        self,
        split_positions=(),
        group_head_and_tail=False,
        chain_style=False,
        output_dir="./tmp/naive_decomposer_dir",
        filter_path=None,
        filter_config=None,
    ):
        """Validate and normalise the decomposer configuration.

        Every entry of ``split_positions`` must be an ``int``. A
        ``filter_config`` of None is replaced by an empty dict; all other
        values are stored unchanged.

        Returns:
            dict with the keys ``split_positions``, ``group_head_and_tail``,
            ``chain_style``, ``output_dir``, ``filter_path`` and
            ``filter_config``.
        """
        for pos in split_positions:
            assert isinstance(
                pos, int
            ), f"split_positions should be list of int, {split_positions=}"
        return {
            "split_positions": split_positions,
            "group_head_and_tail": group_head_and_tail,
            "chain_style": chain_style,
            "output_dir": output_dir,
            "filter_path": filter_path,
            "filter_config": filter_config if filter_config is not None else {},
        }

    def __call__(self, **input_dict):
        """Convert ``self.model`` to a static graph, run it once and return it.

        Args:
            **input_dict: keyword arguments forwarded to the static model.

        Returns:
            The object returned by ``paddle.jit.to_static``.
        """
        # This module has no top-level imports, so ``paddle`` was an undefined
        # name here; importing it locally keeps the block self-contained.
        import paddle

        # NOTE(review): this subset of the config is built but not consumed
        # yet; presumably it is meant for the decomposer step - confirm.
        config = {
            k: v
            for k, v in self.config.items()
            if k in {"split_positions", "group_head_and_tail", "chain_style"}
        }
        static_model = paddle.jit.to_static(
            self.model, input_spec=self.input_spec, full_graph=True
        )
        static_model(**input_dict)
        return static_model

    def get_naive_decomposer_extractor(self, submodule, seq_no):
        """Return a ``NaiveDecomposerExtractor`` bound to this extractor."""
        # NOTE(review): ``NaiveDecomposerExtractor`` is not defined or imported
        # in this file, so this call raises NameError until it is provided.
        return NaiveDecomposerExtractor(self, submodule, seq_no)

graph_net/paddle/run_model.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
import sys
2+
import json
3+
import base64
4+
import argparse
5+
import importlib.util
6+
from typing import Type
7+
8+
import paddle
9+
from graph_net.paddle import utils
10+
11+
12+
def load_class_from_file(file_path: str, class_name: str):
    """Execute the Python file at ``file_path`` and fetch ``class_name`` from it.

    The file is loaded as a module named ``"unnamed"``.

    Returns:
        The attribute called ``class_name`` on the loaded module, or None
        when the module does not define it.
    """
    print(f"Load {class_name} from {file_path}")
    module_spec = importlib.util.spec_from_file_location("unnamed", file_path)
    loaded_module = importlib.util.module_from_spec(module_spec)
    module_spec.loader.exec_module(loaded_module)
    return getattr(loaded_module, class_name, None)
19+
20+
21+
def get_input_dict(model_path):
    """Build the keyword arguments used to call the model stored at ``model_path``.

    The ``"weight_info"`` and ``"input_info"`` entries returned by
    ``utils.load_converted_from_text`` are merged (input entries override
    weight entries with the same key), and every value is passed through
    ``utils.replay_tensor``.
    """
    converted = utils.load_converted_from_text(f"{model_path}")
    tensor_metas = converted["weight_info"]
    # In-place merge, as before: the loaded "weight_info" dict is extended.
    tensor_metas.update(converted["input_info"])

    replayed = {}
    for tensor_name, tensor_meta in tensor_metas.items():
        replayed[tensor_name] = utils.replay_tensor(tensor_meta)
    return replayed
29+
30+
31+
def _convert_to_dict(config_str):
32+
if config_str is None:
33+
return {}
34+
config_str = base64.b64decode(config_str).decode("utf-8")
35+
config = json.loads(config_str)
36+
assert isinstance(config, dict), f"config should be a dict. {config_str=}"
37+
return config
38+
39+
40+
def _get_decorator(args):
41+
if args.decorator_config is None:
42+
return lambda model: model
43+
decorator_config = _convert_to_dict(args.decorator_config)
44+
if "decorator_path" not in decorator_config:
45+
return lambda model: model
46+
decorator_class = load_class_from_file(
47+
decorator_config["decorator_path"], class_name="RunModelDecorator"
48+
)
49+
return decorator_class(decorator_config.get("decorator_config", {}))
50+
51+
52+
def main(args):
    """Load ``GraphModule`` from ``<model_path>/model.py``, decorate it and run it once.

    Args:
        args: parsed command-line namespace; ``model_path`` is read here and
            the whole namespace is passed on to ``_get_decorator``.
    """
    model_path = args.model_path
    module_file = f"{model_path}/model.py"
    graph_module_cls = load_class_from_file(module_file, class_name="GraphModule")
    assert graph_module_cls is not None
    model = graph_module_cls()
    print(f"{model_path=}")

    # The decorator is resolved and applied before the inputs are loaded.
    decorate = _get_decorator(args)
    model = decorate(model)
    call_kwargs = get_input_dict(args.model_path)

    model(**call_kwargs)
65+
66+
67+
if __name__ == "__main__":
    # Command-line entry point: parse the two options and hand them to main().
    parser = argparse.ArgumentParser(description="load and run model")
    option_specs = (
        (
            "--model-path",
            dict(
                type=str,
                required=True,
                help="Path to folder e.g '../../paddle_samples/PaddleX/ResNet18'",
            ),
        ),
        (
            "--decorator-config",
            dict(
                type=str,
                required=False,
                default=None,
                help="decorator configuration string",
            ),
        ),
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    args = parser.parse_args()
    main(args=args)

graph_net/torch/extractor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ def extract(
178178
dynamic (bool): Enable dynamic shape support in torch.compile.
179179
180180
Returns:
181-
wrapper or decorector
181+
wrapper or decorator
182182
183183
Examples:
184184
>>> # wrapper style:

0 commit comments

Comments
 (0)