Skip to content

Commit ef5bb1f

Browse files
committed
Generate a GraphModuleWrapper class to support re-extraction from a GraphNet sample while retaining the parameter information.
1 parent 2da5267 commit ef5bb1f

File tree

2 files changed

+96
-35
lines changed

2 files changed

+96
-35
lines changed

graph_net/paddle/extractor.py

Lines changed: 84 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,60 @@
11
import os
22
import json
3-
4-
os.environ["ENABLE_CINN_IN_DY2ST"] = "0"
5-
# os.environ["FLAGS_logging_trunc_pir_py_code"] = "1"
6-
# os.environ["FLAGS_logging_pir_py_code_int_tensor_element_limit"] = "64"
7-
os.environ["FLAGS_logging_pir_py_code_dir"] = "/tmp/dump"
3+
import importlib.util
84

95
import paddle
10-
from athena.module_op_unittests_for_graphnet import GraphnetSample, generate_samples
6+
from athena.graphnet_samples import GraphnetSample, RunGeneration
7+
from graph_net import imp_util
118
from graph_net.paddle import utils
129

1310

11+
def load_class_from_file(file_path: str, class_name: str):
    """Import the module at *file_path* and return its *class_name* attribute.

    Returns ``None`` when the module defines no such attribute (mirrors
    ``getattr`` with a default rather than raising).
    """
    print(f"Load {class_name} from {file_path}")
    loaded_module = imp_util.load_module(file_path, "unnamed")
    return getattr(loaded_module, class_name, None)
16+
17+
18+
def write_to_file(filepath, content):
    """Write *content* to *filepath*, logging the destination path.

    Overwrites any existing file.  UTF-8 is forced explicitly: this helper
    emits generated Python source, which must round-trip independently of
    the platform's locale-dependent default encoding.
    """
    print(f"Write to {filepath}")
    with open(filepath, "w", encoding="utf-8") as f:
        f.write(content)
22+
23+
24+
def generate_model_wrapper_class(model_dump_path, data_arg_names):
    """Generate, dump, and import a ``GraphModuleWrapper`` class.

    The wrapper subclasses ``paddle.nn.Layer``, keeps extracted parameters as
    attributes (so Paddle tracks them through ``named_parameters()``), and its
    ``forward`` passes the data arguments plus the collected parameter dict to
    the wrapped graph module.

    Args:
        model_dump_path: Directory where ``graph_module_wrapper.py`` is written.
        data_arg_names: Iterable of non-parameter input names; they become the
            generated ``forward`` signature.

    Returns:
        The freshly imported ``GraphModuleWrapper`` class object.
    """
    wrapper_template = """
import paddle

class GraphModuleWrapper(paddle.nn.Layer):
    def __init__(self, graph_module):
        super().__init__()
        self.graph_module = graph_module

    def set_parameters(self, **kwargs):
        for name, value in kwargs.items():
            if isinstance(value, paddle.nn.parameter.Parameter):
                setattr(self, name, value)

    def forward(self, ${DATA_ARG_NAMES}):
        param_dict = { name: param for name, param in self.named_parameters() }
        outputs = self.graph_module(${DATA_ARG_VALUE_PAIRS}, **param_dict)
        return outputs
"""

    # Splice the concrete argument list into the template placeholders.
    value_pairs = [f"{name}={name}" for name in data_arg_names]
    code_str = wrapper_template.replace(
        "${DATA_ARG_NAMES}", ", ".join(data_arg_names)
    )
    code_str = code_str.replace("${DATA_ARG_VALUE_PAIRS}", ", ".join(value_pairs))
    print(code_str)

    wrapper_path = os.path.join(model_dump_path, "graph_module_wrapper.py")
    write_to_file(wrapper_path, code_str)
    return load_class_from_file(
        file_path=wrapper_path, class_name="GraphModuleWrapper"
    )
57+
1458
# used as configuration of python -m graph_net.paddle.run_model
1559
class RunModelDecorator:
1660
def __init__(self, config):
@@ -89,18 +133,43 @@ def run_model_with_dump_enabled(self, model_dump_path, **input_dict):
89133
# Get model dump path
90134
old_flags = self.prepare_to_extract(model_dump_path)
91135

136+
param_dict = {
137+
k: v
138+
for k, v in input_dict.items()
139+
if isinstance(v, paddle.nn.parameter.Parameter)
140+
}
141+
data_dict = {k: v for k, v in input_dict.items() if k not in param_dict}
142+
143+
input_spec = self.input_spec
92144
if self.input_spec is None:
93-
self.input_spec = [
145+
input_spec = [
94146
paddle.static.InputSpec(value.shape, value.dtype, name=name)
95-
for name, value in input_dict.items()
147+
for name, value in data_dict.items()
96148
if isinstance(value, paddle.Tensor)
97149
]
150+
else:
151+
assert len(input_spec) == len(data_dict)
152+
153+
if param_dict:
154+
model_wrapper_class = generate_model_wrapper_class(
155+
model_dump_path, data_dict.keys()
156+
)
157+
wrapped_model = model_wrapper_class(self.model)
158+
wrapped_model.set_parameters(**param_dict)
159+
else:
160+
wrapped_model = self.model
98161

99162
# Run the static model
100163
static_model = paddle.jit.to_static(
101-
self.model, input_spec=self.input_spec, full_graph=True
164+
wrapped_model,
165+
input_spec=input_spec,
166+
full_graph=True,
167+
backend=None,
102168
)
103-
static_model(**input_dict)
169+
static_model.eval()
170+
program = static_model.forward.concrete_program.main_program
171+
# print(program)
172+
static_model(**data_dict)
104173

105174
# Restore the environment
106175
paddle.set_flags(old_flags)
@@ -126,7 +195,7 @@ def translate_pir_program_to_sample_codes(
126195
if split_positions
127196
else None
128197
)
129-
graphnet_samples = generate_samples(
198+
all_samples = RunGeneration(
130199
model_name=self.name,
131200
ir_programs=ir_programs_path,
132201
example_inputs=example_inputs_path,
@@ -136,22 +205,17 @@ def translate_pir_program_to_sample_codes(
136205
)
137206

138207
self.subgraph_idx2samples = {}
139-
for sample in graphnet_samples:
208+
for sample in all_samples:
140209
if sample.subgraph_idx not in self.subgraph_idx2samples.keys():
141210
self.subgraph_idx2samples[sample.subgraph_idx] = []
142211
self.subgraph_idx2samples[sample.subgraph_idx].append(sample)
143212

144213
self.num_subgraphs = len(self.subgraph_idx2samples)
145-
self.num_samples_of_all_subgraphs = len(graphnet_samples)
214+
self.num_samples_of_all_subgraphs = len(all_samples)
146215
assert self.num_subgraphs > 0
147216
return self.subgraph_idx2samples
148217

149218
def write_sample_to_file(self, subgraph_path, sample):
150-
def write_to_file(filepath, content):
151-
print(f"Write to {filepath}")
152-
with open(filepath, "w") as f:
153-
f.write(content)
154-
155219
if not os.path.exists(subgraph_path):
156220
os.makedirs(subgraph_path, exist_ok=True)
157221
write_to_file(f"{subgraph_path}/model.py", sample.model)
@@ -208,14 +272,8 @@ def get_graph_extractor_maker():
208272
custom_extractor_config = extractor_config["custom_extractor_config"]
209273
if custom_extractor_path is None:
210274
return GraphExtractor
211-
import importlib.util as imp
212-
213-
print(f"Import graph_extractor from {custom_extractor_path}")
214-
# import custom_extractor_path as graph_extractor
215-
spec = imp.spec_from_file_location("graph_extractor", custom_extractor_path)
216-
graph_extractor = imp.module_from_spec(spec)
217-
spec.loader.exec_module(graph_extractor)
218-
cls = graph_extractor.GraphExtractor
275+
276+
cls = load_class_from_file(custom_extractor_path, "GraphExtractor")
219277
return lambda *args, **kwargs: cls(custom_extractor_config, *args, **kwargs)
220278

221279
def wrapper(model: paddle.nn.Layer):

graph_net/paddle/run_model.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,21 @@
1+
import os
12
import sys
23
import json
34
import base64
45
import argparse
5-
import importlib.util
66
from typing import Type
77

8+
os.environ["FLAGS_logging_pir_py_code_dir"] = "/tmp/dump"
9+
810
import paddle
11+
from graph_net import imp_util
912
from graph_net.paddle import utils
1013

1114

1215
def load_class_from_file(file_path: str, class_name: str):
    """Load *file_path* as a module via ``imp_util`` and look up *class_name*.

    Missing attributes yield ``None`` instead of an ``AttributeError``.
    """
    print(f"Load {class_name} from {file_path}")
    loaded = imp_util.load_module(file_path, "unnamed")
    target_class = getattr(loaded, class_name, None)
    return target_class
1920

2021

@@ -23,8 +24,11 @@ def get_input_dict(model_path):
2324
params = inputs_params["weight_info"]
2425
inputs = inputs_params["input_info"]
2526

26-
params.update(inputs)
27-
state_dict = {k: utils.replay_tensor(v) for k, v in params.items()}
27+
state_dict = {}
28+
for k, v in params.items():
29+
state_dict[k] = paddle.nn.parameter.Parameter(utils.replay_tensor(v), name=k)
30+
for k, v in inputs.items():
31+
state_dict[k] = utils.replay_tensor(v)
2832
return state_dict
2933

3034

@@ -58,9 +62,8 @@ def main(args):
5862
model = model_class()
5963
print(f"{model_path=}")
6064

61-
model = _get_decorator(args)(model)
6265
input_dict = get_input_dict(args.model_path)
63-
66+
model = _get_decorator(args)(model)
6467
model(**input_dict)
6568

6669

0 commit comments

Comments
 (0)