
Commit 7ee8b99

[Feature Enhancement] Implement extractor and naive_decomposer for paddle. (#376)

* Implement extractor for paddle.
* Implement NaiveDecomposer for paddle.
* Generate a GraphModuleWrapper class to support re-extraction from a GraphNet sample while retaining the parameter information.
* Add Athena to requirements.

1 parent 8768f95 commit 7ee8b99

File tree

7 files changed: +592 −102 lines

graph_net/paddle/extractor.py

Lines changed: 315 additions & 0 deletions
@@ -0,0 +1,315 @@
import os
import json
import importlib.util

import paddle
from athena.graphnet_samples import GraphnetSample, RunGeneration
from graph_net import imp_util
from graph_net.paddle import utils


def load_class_from_file(file_path: str, class_name: str):
    print(f"Load {class_name} from {file_path}")
    module = imp_util.load_module(file_path, "unnamed")
    model_class = getattr(module, class_name, None)
    return model_class


def write_to_file(filepath, content):
    print(f"Write to {filepath}")
    with open(filepath, "w") as f:
        f.write(content)


def generate_model_wrapper_class(model_dump_path, data_arg_names):
    graph_module_wrapper_class_template = """
import paddle

class GraphModuleWrapper(paddle.nn.Layer):
    def __init__(self, graph_module):
        super().__init__()
        self.graph_module = graph_module

    def set_parameters(self, **kwargs):
        for name, value in kwargs.items():
            if isinstance(value, paddle.nn.parameter.Parameter):
                setattr(self, name, value)

    def forward(self, ${DATA_ARG_NAMES}):
        param_dict = {name: param for name, param in self.named_parameters()}
        outputs = self.graph_module(${DATA_ARG_VALUE_PAIRS}, **param_dict)
        return outputs
"""

    data_arg_value_pairs = [f"{name}={name}" for name in data_arg_names]
    graph_module_wrapper_class_code_str = graph_module_wrapper_class_template.replace(
        "${DATA_ARG_NAMES}", ", ".join(data_arg_names)
    ).replace("${DATA_ARG_VALUE_PAIRS}", ", ".join(data_arg_value_pairs))
    print(graph_module_wrapper_class_code_str)

    file_path = os.path.join(model_dump_path, "graph_module_wrapper.py")
    write_to_file(file_path, graph_module_wrapper_class_code_str)
    model_class = load_class_from_file(
        file_path=file_path, class_name="GraphModuleWrapper"
    )
    return model_class

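# For illustration, with hypothetical data_arg_names=["x", "y"] the rendered
# template's forward becomes:
#     def forward(self, x, y):
#         param_dict = {name: param for name, param in self.named_parameters()}
#         outputs = self.graph_module(x=x, y=y, **param_dict)
#         return outputs
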
# used as configuration of python -m graph_net.paddle.run_model
class RunModelDecorator:
    def __init__(self, config):
        self.config = self.make_config(**config)

    def __call__(self, model):
        return extract(**self.config)(model)

    def make_config(
        self,
        name=None,
        dynamic=False,
        input_spec=None,
        custom_extractor_path: str = None,
        custom_extractor_config: dict = None,
    ):
        assert name is not None
        return {
            "name": name,
            "dynamic": dynamic,
            "input_spec": input_spec,
            "extractor_config": {
                "custom_extractor_path": custom_extractor_path,
                "custom_extractor_config": custom_extractor_config,
            },
        }

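# Illustrative config of the shape RunModelDecorator expects; the model name
# "demo" is hypothetical:
#     RunModelDecorator({"name": "demo", "dynamic": False, "input_spec": None})
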
class GraphExtractor:
    def __init__(
        self,
        model,
        name,
        dynamic=False,
        input_spec=None,
        workspace_path=None,
    ):
        self.model = model
        self.name = name
        self.dynamic = dynamic
        self.input_spec = input_spec
        assert not self.dynamic, "dynamic=True is not supported yet!"

        self.num_subgraphs = 0
        self.num_samples_of_all_subgraphs = 0
        self.subgraph_idx2samples = None

        dump_path = os.environ.get("GRAPH_NET_PIR_DUMP_WORKSPACE", "/tmp")
        self.dump_path = os.path.abspath(dump_path)

        workspace_path = (
            workspace_path
            if workspace_path is not None
            else os.environ.get("GRAPH_NET_EXTRACT_WORKSPACE")
        )
        # Check before os.path.abspath: abspath(None) would raise a TypeError.
        if not workspace_path:
            raise EnvironmentError(
                "Environment variable 'GRAPH_NET_EXTRACT_WORKSPACE' is not set."
            )
        self.workspace_path = os.path.abspath(workspace_path)

    def prepare_to_extract(self, model_dump_path):
        os.makedirs(model_dump_path, exist_ok=True)
        new_flags = {
            "FLAGS_logging_trunc_pir_py_code": 1,
            "FLAGS_logging_pir_py_code_int_tensor_element_limit": 64,
            "FLAGS_logging_pir_py_code_dir": model_dump_path,
        }
        old_flags = paddle.get_flags(list(new_flags.keys()))

        print(f"Set pir dumping path to {model_dump_path}")
        paddle.set_flags(new_flags)
        return old_flags

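    # The FLAGS_logging_pir_py_code_* flags above direct Paddle to dump PIR
    # programs into model_dump_path; translate_pir_program_to_sample_codes()
    # expects exec_programs.py and programs_example_input_tensor_meta.py to
    # appear there.
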
    def run_model_with_dump_enabled(self, model_dump_path, **input_dict):
        # Enable PIR dumping into model_dump_path
        old_flags = self.prepare_to_extract(model_dump_path)

        param_dict = {
            k: v
            for k, v in input_dict.items()
            if isinstance(v, paddle.nn.parameter.Parameter)
        }
        data_dict = {k: v for k, v in input_dict.items() if k not in param_dict}

        input_spec = self.input_spec
        if self.input_spec is None:
            input_spec = [
                paddle.static.InputSpec(value.shape, value.dtype, name=name)
                for name, value in data_dict.items()
                if isinstance(value, paddle.Tensor)
            ]
        else:
            assert len(input_spec) == len(data_dict)

        if param_dict:
            model_wrapper_class = generate_model_wrapper_class(
                model_dump_path, data_dict.keys()
            )
            wrapped_model = model_wrapper_class(self.model)
            wrapped_model.set_parameters(**param_dict)
        else:
            wrapped_model = self.model

        # Run the static model
        static_model = paddle.jit.to_static(
            wrapped_model,
            input_spec=input_spec,
            full_graph=True,
            backend=None,
        )
        static_model.eval()
        program = static_model.forward.concrete_program.main_program
        # print(program)
        static_model(**data_dict)

        # Restore the environment
        paddle.set_flags(old_flags)
        return static_model

    def translate_pir_program_to_sample_codes(
        self, model_dump_path, split_positions=None
    ):
        ir_programs_path = os.path.join(model_dump_path, "exec_programs.py")
        example_inputs_path = os.path.join(
            model_dump_path, "programs_example_input_tensor_meta.py"
        )
        assert os.path.isfile(
            ir_programs_path
        ), f"{ir_programs_path} is not a regular file."
        assert os.path.isfile(
            example_inputs_path
        ), f"{example_inputs_path} is not a regular file."

        # Arguments for graph decomposer
        op_example_inputs_path = (
            os.path.join(model_dump_path, "op_example_input_tensor_meta.py")
            if split_positions
            else None
        )
        all_samples = RunGeneration(
            model_name=self.name,
            ir_programs=ir_programs_path,
            example_inputs=example_inputs_path,
            op_example_inputs=op_example_inputs_path,
            split_positions=split_positions,
            eval_mode=True,
        )

        self.subgraph_idx2samples = {}
        for sample in all_samples:
            if sample.subgraph_idx not in self.subgraph_idx2samples:
                self.subgraph_idx2samples[sample.subgraph_idx] = []
            self.subgraph_idx2samples[sample.subgraph_idx].append(sample)

        self.num_subgraphs = len(self.subgraph_idx2samples)
        self.num_samples_of_all_subgraphs = len(all_samples)
        assert self.num_subgraphs > 0
        return self.subgraph_idx2samples

    def write_sample_to_file(self, subgraph_path, sample):
        os.makedirs(subgraph_path, exist_ok=True)
        write_to_file(f"{subgraph_path}/model.py", sample.model)
        write_to_file(f"{subgraph_path}/weight_meta.py", sample.weight_meta)
        write_to_file(f"{subgraph_path}/input_meta.py", sample.input_meta)
        with open(os.path.join(subgraph_path, "graph_net.json"), "w") as f:
            json.dump(sample.metadata, f, indent=4)

    def __call__(self, **input_dict):
        # 1. Run the model to dump pir programs
        model_dump_path = os.path.join(self.dump_path, self.name)
        static_model = self.run_model_with_dump_enabled(model_dump_path, **input_dict)

        # 2. Convert pir programs to graphnet samples
        self.translate_pir_program_to_sample_codes(
            model_dump_path, split_positions=None
        )

        # 3. Save to model_path
        model_path = os.path.join(self.workspace_path, self.name)
        for subgraph_idx, samples in self.subgraph_idx2samples.items():
            assert len(samples) == 1
            if self.num_samples_of_all_subgraphs == 1:
                subgraph_path = model_path
            else:
                subgraph_path = os.path.join(model_path, f"subgraph_{subgraph_idx}")
            self.write_sample_to_file(subgraph_path, samples[0])

        print(
            f"Graph and tensors for '{self.name}' extracted successfully to: {model_path}"
        )
        return static_model

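    # Resulting layout: with exactly one sample overall, model.py,
    # weight_meta.py, input_meta.py and graph_net.json are written directly
    # under $GRAPH_NET_EXTRACT_WORKSPACE/<name>/; otherwise each subgraph goes
    # to <name>/subgraph_<idx>/.
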
def extract(name, dynamic=False, input_spec=None, extractor_config: dict = None):
    """
    Extract computation graphs from a PaddlePaddle nn.Layer.
    The extracted computation graphs are saved into the directory named by the
    environment variable $GRAPH_NET_EXTRACT_WORKSPACE.

    Args:
        name (str): The name of the model, used as the directory name for saving.
        dynamic (bool): Enable dynamic shape support in paddle.jit.to_static.
        input_spec (list[InputSpec] | tuple[InputSpec]): InputSpec for input tensors,
            including each tensor's name, shape and dtype. When dynamic is False,
            input_spec can be inferred automatically.

    Returns:
        wrapper or decorator
    """

    extractor_config = make_extractor_config(extractor_config)

    def get_graph_extractor_maker():
        custom_extractor_path = extractor_config["custom_extractor_path"]
        custom_extractor_config = extractor_config["custom_extractor_config"]
        if custom_extractor_path is None:
            return GraphExtractor

        cls = load_class_from_file(custom_extractor_path, "GraphExtractor")
        return lambda *args, **kwargs: cls(custom_extractor_config, *args, **kwargs)

    def wrapper(model: paddle.nn.Layer):
        assert isinstance(model, paddle.nn.Layer), f"{type(model)=}"
        extractor = get_graph_extractor_maker()(model, name, dynamic, input_spec)
        return extractor

    def decorator(module_class):
        def constructor(*args, **kwargs):
            return wrapper(module_class(*args, **kwargs))

        return constructor

    def decorator_or_wrapper(obj):
        if isinstance(obj, paddle.nn.Layer):
            return wrapper(obj)
        # Guard with isinstance(obj, type): issubclass raises a TypeError on
        # non-class arguments, which would bypass the error message below.
        elif isinstance(obj, type) and issubclass(obj, paddle.nn.Layer):
            return decorator(obj)
        else:
            raise NotImplementedError(
                "Only paddle.nn.Layer instances or subclasses are supported."
            )

    return decorator_or_wrapper


def make_extractor_config(extractor_config):
    kwargs = extractor_config if extractor_config is not None else {}
    return make_extractor_config_impl(**kwargs)


def make_extractor_config_impl(
    custom_extractor_path: str = None, custom_extractor_config: dict = None
):
    config = custom_extractor_config if custom_extractor_config is not None else {}
    return {
        "custom_extractor_path": custom_extractor_path,
        "custom_extractor_config": config,
    }
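
A minimal usage sketch of the new extractor; the model name "demo", the Demo
layer, and the workspace path are hypothetical, not part of this commit:

    import os
    import paddle
    from graph_net.paddle.extractor import extract

    # GraphExtractor refuses to run without a workspace directory.
    os.environ["GRAPH_NET_EXTRACT_WORKSPACE"] = "/tmp/graph_net_workspace"

    @extract(name="demo")
    class Demo(paddle.nn.Layer):
        def forward(self, x):
            return paddle.nn.functional.relu(x)

    model = Demo()  # returns a GraphExtractor wrapping the layer instance
    # Calling it runs the layer under paddle.jit.to_static with PIR dumping
    # enabled, then writes GraphNet samples into the workspace.
    model(x=paddle.rand([4, 8]))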
