Skip to content

Commit 902654e

Browse files
committed
Implement NaiveDecomposer for paddle.
1 parent d9c4002 commit 902654e

File tree

4 files changed

+359
-48
lines changed

4 files changed

+359
-48
lines changed

graph_net/paddle/extractor.py

Lines changed: 155 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,34 @@
1111
from graph_net.paddle import utils
1212

1313

14+
# used as configuration of python -m graph_net.paddle.run_model
class RunModelDecorator:
    """Adapter turning a plain config dict into an ``extract(...)`` model wrapper.

    Instances are constructed from the configuration of
    ``python -m graph_net.paddle.run_model`` and applied to a model like a
    decorator.
    """

    def __init__(self, config):
        # Validate and normalize the raw dict once, at construction time,
        # so errors surface before any model is run.
        self.config = self.make_config(**config)

    def __call__(self, model):
        # Delegate to the module-level extract() factory with the
        # normalized configuration.
        return extract(**self.config)(model)

    def make_config(
        self,
        name=None,
        dynamic=False,
        input_spec=None,
        custom_extractor_path: str = None,
        custom_extractor_config: dict = None,
    ):
        """Validate the raw options and shape them for ``extract``.

        Args:
            name: Required model name (directory name for the dump).
            dynamic: Forwarded to ``extract``; dynamic-shape support flag.
            input_spec: Optional InputSpec list for ``paddle.jit.to_static``.
            custom_extractor_path: Optional path to a module providing a
                custom ``GraphExtractor`` class.
            custom_extractor_config: Optional config dict for that class.

        Raises:
            ValueError: If ``name`` is missing.
        """
        # Explicit raise instead of `assert`: asserts are stripped under
        # `python -O`, which would silently skip this validation.
        if name is None:
            raise ValueError("'name' is required in the run_model configuration")
        return {
            "name": name,
            "dynamic": dynamic,
            "input_spec": input_spec,
            "extractor_config": {
                "custom_extractor_path": custom_extractor_path,
                "custom_extractor_config": custom_extractor_config,
            },
        }
40+
41+
1442
class GraphExtractor:
1543
def __init__(
1644
self,
@@ -26,45 +54,39 @@ def __init__(
2654
self.input_spec = input_spec
2755
assert not self.dynamic, "dynamic=True is not supported now!"
2856

29-
self.subgraph_counter = 0
30-
self.dump_path = os.environ.get("GRAPH_NET_PIR_DUMP_WORKSPACE", "/tmp")
31-
self.workspace_path = (
57+
self.num_subgraphs = 0
58+
self.num_samples_of_all_subgraphs = 0
59+
self.subgraph_idx2samples = None
60+
61+
dump_path = os.environ.get("GRAPH_NET_PIR_DUMP_WORKSPACE", "/tmp")
62+
self.dump_path = os.path.abspath(dump_path)
63+
64+
workspace_path = (
3265
workspace_path
3366
if workspace_path is not None
3467
else os.environ.get("GRAPH_NET_EXTRACT_WORKSPACE")
3568
)
69+
self.workspace_path = os.path.abspath(workspace_path)
3670
if not self.workspace_path:
3771
raise EnvironmentError(
3872
"Environment variable 'GRAPH_NET_EXTRACT_WORKSPACE' is not set."
3973
)
4074

4175
def prepare_to_extract(self, model_dump_path):
    """Switch paddle into PIR python-code dumping mode targeting *model_dump_path*.

    Creates the dump directory, snapshots the current values of the
    relevant paddle FLAGS, then overwrites them so subsequent static-graph
    runs dump their PIR programs there.

    Returns:
        The previous FLAG values (as returned by ``paddle.get_flags``),
        suitable for later restoration via ``paddle.set_flags``.
    """
    os.makedirs(model_dump_path, exist_ok=True)
    flags_to_set = {
        "FLAGS_logging_trunc_pir_py_code": 1,
        "FLAGS_logging_pir_py_code_int_tensor_element_limit": 64,
        "FLAGS_logging_pir_py_code_dir": model_dump_path,
    }
    # Snapshot the current values first so the caller can restore them.
    saved_flags = paddle.get_flags(list(flags_to_set))

    print(f"Set pir dumping path to {model_dump_path}")
    paddle.set_flags(flags_to_set)
    return saved_flags
5987

60-
def write_to_file(self, filepath, content):
61-
print(f"Write to {filepath}")
62-
with open(filepath, "w") as f:
63-
f.write(content)
64-
65-
def __call__(self, **input_dict):
66-
# 1. Get model dump path
67-
model_dump_path = os.path.join(self.dump_path, self.name)
88+
def run_model_with_dump_enabled(self, model_dump_path, **input_dict):
89+
# Get model dump path
6890
old_flags = self.prepare_to_extract(model_dump_path)
6991

7092
if self.input_spec is None:
@@ -74,13 +96,19 @@ def __call__(self, **input_dict):
7496
if isinstance(value, paddle.Tensor)
7597
]
7698

77-
# 2. Run the model to dump pir programs
99+
# Run the static model
78100
static_model = paddle.jit.to_static(
79101
self.model, input_spec=self.input_spec, full_graph=True
80102
)
81103
static_model(**input_dict)
82104

83-
# 3. Convert pir programs to graphnet samples
105+
# Restore the environment
106+
paddle.set_flags(old_flags)
107+
return static_model
108+
109+
def translate_pir_program_to_sample_codes(
110+
self, model_dump_path, split_positions=None
111+
):
84112
ir_programs_path = os.path.join(model_dump_path, "exec_programs.py")
85113
example_inputs_path = os.path.join(
86114
model_dump_path, "programs_example_input_tensor_meta.py"
@@ -92,43 +120,108 @@ def __call__(self, **input_dict):
92120
example_inputs_path
93121
), f"{example_inputs_path} is not a regular file."
94122

123+
# Arguments for graph decomposer
124+
op_example_inputs_path = (
125+
os.path.join(model_dump_path, "op_example_input_tensor_meta.py")
126+
if split_positions
127+
else None
128+
)
95129
graphnet_samples = generate_samples(
96130
model_name=self.name,
97131
ir_programs=ir_programs_path,
98132
example_inputs=example_inputs_path,
133+
op_example_inputs=op_example_inputs_path,
134+
split_positions=split_positions,
99135
eval_mode=True,
100136
)
101137

102-
# 4. Save to model_path
138+
self.subgraph_idx2samples = {}
139+
for sample in graphnet_samples:
140+
if sample.subgraph_idx not in self.subgraph_idx2samples.keys():
141+
self.subgraph_idx2samples[sample.subgraph_idx] = []
142+
self.subgraph_idx2samples[sample.subgraph_idx].append(sample)
143+
144+
self.num_subgraphs = len(self.subgraph_idx2samples)
145+
self.num_samples_of_all_subgraphs = len(graphnet_samples)
146+
assert self.num_subgraphs > 0
147+
return self.subgraph_idx2samples
148+
149+
def write_sample_to_file(self, subgraph_path, sample):
    """Persist one graphnet sample under *subgraph_path*.

    Writes ``model.py``, ``weight_meta.py`` and ``input_meta.py`` as plain
    text, and ``graph_net.json`` as pretty-printed JSON metadata. The
    directory is created if it does not exist.
    """
    # exist_ok=True already tolerates a pre-existing directory, so no
    # separate existence check is needed.
    os.makedirs(subgraph_path, exist_ok=True)

    def dump_text(filepath, content):
        # Echo each destination so users can locate the artifacts.
        print(f"Write to {filepath}")
        with open(filepath, "w") as f:
            f.write(content)

    dump_text(f"{subgraph_path}/model.py", sample.model)
    dump_text(f"{subgraph_path}/weight_meta.py", sample.weight_meta)
    dump_text(f"{subgraph_path}/input_meta.py", sample.input_meta)
    with open(os.path.join(subgraph_path, "graph_net.json"), "w") as f:
        json.dump(sample.metadata, f, indent=4)
162+
163+
def __call__(self, **input_dict):
    """Extract the model's graph(s) and persist them under the workspace.

    Pipeline: run the model once with PIR dumping enabled, translate the
    dumped programs into graphnet samples, then write each subgraph's
    sample to disk. Returns the static (to_static-converted) model.
    """
    # 1. Run the model to dump pir programs
    model_dump_path = os.path.join(self.dump_path, self.name)
    static_model = self.run_model_with_dump_enabled(model_dump_path, **input_dict)

    # 2. Convert pir programs to graphnet samples
    self.translate_pir_program_to_sample_codes(model_dump_path, split_positions=None)

    # 3. Save to model_path
    model_path = os.path.join(self.workspace_path, self.name)
    single_sample = self.num_samples_of_all_subgraphs == 1
    for subgraph_idx, samples in self.subgraph_idx2samples.items():
        # Without split positions each subgraph yields exactly one sample.
        assert len(samples) == 1
        # A lone sample goes directly into model_path; otherwise each
        # subgraph gets its own subdirectory.
        if single_sample:
            subgraph_path = model_path
        else:
            subgraph_path = os.path.join(model_path, f"subgraph_{subgraph_idx}")
        self.write_sample_to_file(subgraph_path, samples[0])

    print(
        f"Graph and tensors for '{self.name}' extracted successfully to: {model_path}"
    )
    return static_model
126187

127188

128-
def extract(name, dynamic=False, input_spec=None):
189+
def extract(name, dynamic=False, input_spec=None, extractor_config: dict = None):
190+
"""
191+
Extract computation graphs from PaddlePaddle nn.Layer.
192+
The extracted computation graphs will be saved into directory of env var $GRAPH_NET_EXTRACT_WORKSPACE.
193+
194+
Args:
195+
name (str): The name of the model, used as the directory name for saving.
196+
dynamic (bool): Enable dynamic shape support in paddle.jit.to_static.
197+
input_spec (list[InputSpec] | tuple[InputSpec]): InputSpec for input tensors, which includes tensor's name, shape and dtype.
198+
When dynamic is False, input_spec can be inferred automatically.
199+
200+
Returns:
201+
wrapper or decorator
202+
"""
203+
204+
extractor_config = make_extractor_config(extractor_config)
205+
206+
def get_graph_extractor_maker():
207+
custom_extractor_path = extractor_config["custom_extractor_path"]
208+
custom_extractor_config = extractor_config["custom_extractor_config"]
209+
if custom_extractor_path is None:
210+
return GraphExtractor
211+
import importlib.util as imp
212+
213+
print(f"Import graph_extractor from {custom_extractor_path}")
214+
# import custom_extractor_path as graph_extractor
215+
spec = imp.spec_from_file_location("graph_extractor", custom_extractor_path)
216+
graph_extractor = imp.module_from_spec(spec)
217+
spec.loader.exec_module(graph_extractor)
218+
cls = graph_extractor.GraphExtractor
219+
return lambda *args, **kwargs: cls(custom_extractor_config, *args, **kwargs)
220+
129221
def wrapper(model: paddle.nn.Layer):
130222
assert isinstance(model, paddle.nn.Layer), f"{type(model)=}"
131-
return GraphExtractor(model, name, dynamic, input_spec)
223+
extractor = get_graph_extractor_maker()(model, name, dynamic, input_spec)
224+
return extractor
132225

133226
def decorator(module_class):
134227
def constructor(*args, **kwargs):
@@ -147,3 +240,18 @@ def decorator_or_wrapper(obj):
147240
)
148241

149242
return decorator_or_wrapper
243+
244+
245+
def make_extractor_config(extractor_config):
    """Normalize an optional extractor-config dict into canonical form.

    Args:
        extractor_config: Raw options dict, or ``None`` for all defaults.
            Unknown keys raise ``TypeError`` via keyword expansion.

    Returns:
        dict with keys ``custom_extractor_path`` and ``custom_extractor_config``.
    """
    # `or {}` is safe here: the only falsy dict is {}, which maps to the
    # same result as an explicit empty dict.
    return make_extractor_config_impl(**(extractor_config or {}))


def make_extractor_config_impl(
    custom_extractor_path: str = None, custom_extractor_config: dict = None
):
    """Build the canonical extractor config, defaulting the inner config to {}."""
    return {
        "custom_extractor_path": custom_extractor_path,
        "custom_extractor_config": custom_extractor_config or {},
    }
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
import os
2+
from graph_net.paddle.extractor import GraphExtractor as BuiltinGraphExtractor
3+
4+
5+
class GraphExtractor:
    """Naive-decomposer front-end with the builtin extractor's interface.

    Wraps a model like the builtin extractor but splits the captured graph
    at the configured positions by delegating to ``NaiveDecomposerExtractor``.
    """

    def __init__(
        self,
        config: dict,
        model,
        name,
        dynamic,
        input_spec=None,
    ):
        self.model = model
        self.name = name
        self.dynamic = dynamic
        self.input_spec = input_spec
        self.config = self.make_config(**config)
        # Cache the decomposer so its run-once `extracted` flag survives
        # across calls. (Bug fix: previously a fresh extractor was created
        # per call, so extraction re-ran on every invocation.)
        self._naive_decomposer_extractor = None

    def make_config(
        self,
        split_positions=(),
        group_head_and_tail=False,
        chain_style=False,
        output_dir="./tmp/naive_decomposer_dir",
    ):
        """Validate and normalize the decomposer options.

        Args:
            split_positions: Iterable of int positions at which to split.
            group_head_and_tail: Decomposition option — forwarded as-is.
            chain_style: Decomposition option — forwarded as-is.
            output_dir: Workspace directory for the builtin extractor.

        Raises:
            TypeError: If any split position is not an int.
        """
        for pos in split_positions:
            # Explicit raise instead of `assert`: asserts vanish under -O.
            if not isinstance(pos, int):
                raise TypeError(
                    f"split_positions should be list of int, {split_positions=}"
                )
        return {
            "split_positions": split_positions,
            "group_head_and_tail": group_head_and_tail,
            "chain_style": chain_style,
            "output_dir": output_dir,
        }

    def __call__(self, **input_dict):
        extracted_model = self.get_naive_decomposer_extractor()(**input_dict)
        return extracted_model

    def get_naive_decomposer_extractor(self):
        # Lazily construct exactly one extractor instance and reuse it.
        if self._naive_decomposer_extractor is None:
            self._naive_decomposer_extractor = NaiveDecomposerExtractor(self)
        return self._naive_decomposer_extractor
44+
45+
46+
class NaiveDecomposerExtractor:
    """Runs the builtin extractor once, writing each decomposed subgraph sample.

    Extraction is performed at most once per instance; subsequent calls are
    no-ops that return ``None``.
    """

    def __init__(self, parent_graph_extractor):
        super().__init__()
        self.parent_graph_extractor = parent_graph_extractor
        # Guards against re-running the (expensive) extraction.
        self.extracted = False
        parent_config = parent_graph_extractor.config
        self.builtin_extractor = BuiltinGraphExtractor(
            model=parent_graph_extractor.model,
            name=parent_graph_extractor.name,
            dynamic=parent_graph_extractor.dynamic,
            input_spec=parent_graph_extractor.input_spec,
            workspace_path=parent_config["output_dir"],
        )
        self.split_positions = parent_config["split_positions"]
        self.post_process = self.make_post_process(parent_config)

    def do_extract(self, **input_dict):
        """Run the full decompose-and-dump pipeline once.

        Returns the static model produced by the builtin extractor; also
        records every written subgraph directory in ``self.subgraph_path_list``.
        """
        builtin = self.builtin_extractor

        # 1. Run the model with PIR dumping enabled.
        model_dump_path = os.path.join(builtin.dump_path, builtin.name)
        static_model = builtin.run_model_with_dump_enabled(
            model_dump_path, **input_dict
        )

        # 2. Translate the dumped PIR programs into graphnet samples,
        #    splitting at the configured positions.
        builtin.translate_pir_program_to_sample_codes(
            model_dump_path, split_positions=self.split_positions
        )

        # 3. Write every sample and remember where each one went.
        self.subgraph_path_list = []
        model_path = os.path.join(builtin.workspace_path, builtin.name)
        for subgraph_idx, samples in builtin.subgraph_idx2samples.items():
            for seq_idx, sample in enumerate(samples):
                # Directory layout: a lone sample lives directly in
                # model_path; one sample per subgraph gets subgraph_<i>;
                # multiple samples per subgraph get subgraph_<i>_<j>.
                lone_sample = (
                    builtin.num_samples_of_all_subgraphs == 1 and len(samples) == 1
                )
                if lone_sample:
                    subgraph_path = model_path
                elif len(samples) == 1:
                    subgraph_path = os.path.join(
                        model_path, f"subgraph_{subgraph_idx}"
                    )
                else:
                    subgraph_path = os.path.join(
                        model_path, f"subgraph_{subgraph_idx}_{seq_idx}"
                    )
                self.subgraph_path_list.append(subgraph_path)
                builtin.write_sample_to_file(subgraph_path, sample)
        print(
            f"Graph and tensors for '{builtin.name}' extracted successfully to: {model_path}"
        )
        return static_model

    def __call__(self, **input_dict):
        """Extract on the first call; later calls return ``None``."""
        if self.extracted:
            # NOTE(review): a post-process pass over subgraph_path_list
            # appears planned here but is not implemented yet.
            return None
        extracted_model = self.do_extract(**input_dict)
        self.extracted = True
        return extracted_model

    def make_post_process(self, config):
        # Post-processing hook is not implemented yet; presumably a future
        # config key will name a module to load — TODO confirm the design.
        return None

0 commit comments

Comments
 (0)