Skip to content

Commit ea5672b

Browse files
committed
Generate a GraphModuleWrapper class to support re-extraction from a GraphNet sample while retaining the parameter information.
1 parent 2da5267 commit ea5672b

File tree

2 files changed

+92
-18
lines changed

2 files changed

+92
-18
lines changed

graph_net/paddle/extractor.py

Lines changed: 83 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,61 @@
11
import os
22
import json
3-
4-
os.environ["ENABLE_CINN_IN_DY2ST"] = "0"
5-
# os.environ["FLAGS_logging_trunc_pir_py_code"] = "1"
6-
# os.environ["FLAGS_logging_pir_py_code_int_tensor_element_limit"] = "64"
7-
os.environ["FLAGS_logging_pir_py_code_dir"] = "/tmp/dump"
3+
import importlib.util
84

95
import paddle
106
from athena.module_op_unittests_for_graphnet import GraphnetSample, generate_samples
117
from graph_net.paddle import utils
128

139

10+
def load_class_from_file(file_path: str, class_name: str):
    """Import *file_path* as an anonymous module and look up *class_name* in it.

    Returns:
        The class object bound to *class_name*, or ``None`` when the module
        defines no attribute with that name.
    """
    print(f"Load {class_name} from {file_path}")
    module_spec = importlib.util.spec_from_file_location("unnamed", file_path)
    loaded_module = importlib.util.module_from_spec(module_spec)
    module_spec.loader.exec_module(loaded_module)
    return getattr(loaded_module, class_name, None)
17+
18+
19+
def write_to_file(filepath, content):
    """Write *content* to *filepath*, logging the destination first.

    Args:
        filepath: Target path; the file is created or truncated.
        content: Text to write.
    """
    print(f"Write to {filepath}")
    # Explicit encoding: the content is generated Python source, so it must be
    # UTF-8 regardless of the platform's default locale encoding.
    with open(filepath, "w", encoding="utf-8") as f:
        f.write(content)
23+
24+
25+
def generate_model_wrapper_class(model_dump_path, data_arg_names):
    """Generate, persist, and load a ``GraphModuleWrapper`` layer class.

    The generated class wraps a graph module so that parameters registered
    via ``set_parameters`` are re-collected in ``forward`` and forwarded to
    the wrapped module alongside the named data arguments.

    Args:
        model_dump_path: Directory in which ``graph_module_wrapper.py`` is written.
        data_arg_names: Iterable of argument names for the generated ``forward``.

    Returns:
        The freshly loaded ``GraphModuleWrapper`` class object.
    """
    template = """
import paddle

class GraphModuleWrapper(paddle.nn.Layer):
    def __init__(self, graph_module):
        super().__init__()
        self.graph_module = graph_module

    def set_parameters(self, **kwargs):
        for name, value in kwargs.items():
            if isinstance(value, paddle.nn.parameter.Parameter):
                setattr(self, name, value)

    def forward(self, ${DATA_ARG_NAMES}):
        param_dict = { name: param for name, param in self.named_parameters() }
        outputs = self.graph_module(${DATA_ARG_VALUE_PAIRS}, **param_dict)
        return outputs
"""

    names_csv = ", ".join(data_arg_names)
    pairs_csv = ", ".join(f"{name}={name}" for name in data_arg_names)
    code_str = template.replace("${DATA_ARG_NAMES}", names_csv).replace(
        "${DATA_ARG_VALUE_PAIRS}", pairs_csv
    )
    print(code_str)

    file_path = os.path.join(model_dump_path, "graph_module_wrapper.py")
    write_to_file(file_path, code_str)
    return load_class_from_file(file_path=file_path, class_name="GraphModuleWrapper")
58+
1459
# used as configuration of python -m graph_net.paddle.run_model
1560
class RunModelDecorator:
1661
def __init__(self, config):
@@ -89,18 +134,47 @@ def run_model_with_dump_enabled(self, model_dump_path, **input_dict):
89134
# Get model dump path
90135
old_flags = self.prepare_to_extract(model_dump_path)
91136

137+
param_dict = {
138+
k: v
139+
for k, v in input_dict.items()
140+
if isinstance(v, paddle.nn.parameter.Parameter)
141+
}
142+
data_dict = {
143+
k: v
144+
for k, v in input_dict.items()
145+
if not isinstance(v, paddle.nn.parameter.Parameter)
146+
}
147+
148+
input_spec = self.input_spec
92149
if self.input_spec is None:
93-
self.input_spec = [
150+
input_spec = [
94151
paddle.static.InputSpec(value.shape, value.dtype, name=name)
95-
for name, value in input_dict.items()
152+
for name, value in data_dict.items()
96153
if isinstance(value, paddle.Tensor)
97154
]
155+
else:
156+
assert len(input_spec) == len(data_dict)
157+
158+
if param_dict:
159+
model_wrapper_class = generate_model_wrapper_class(
160+
model_dump_path, data_dict.keys()
161+
)
162+
wrapped_model = model_wrapper_class(self.model)
163+
wrapped_model.set_parameters(**param_dict)
164+
else:
165+
wrapped_model = self.model
98166

99167
# Run the static model
100168
static_model = paddle.jit.to_static(
101-
self.model, input_spec=self.input_spec, full_graph=True
169+
wrapped_model,
170+
input_spec=input_spec,
171+
full_graph=True,
172+
backend=None,
102173
)
103-
static_model(**input_dict)
174+
static_model.eval()
175+
program = static_model.forward.concrete_program.main_program
176+
# print(program)
177+
static_model(**data_dict)
104178

105179
# Restore the environment
106180
paddle.set_flags(old_flags)
@@ -147,11 +221,6 @@ def translate_pir_program_to_sample_codes(
147221
return self.subgraph_idx2samples
148222

149223
def write_sample_to_file(self, subgraph_path, sample):
150-
def write_to_file(filepath, content):
151-
print(f"Write to {filepath}")
152-
with open(filepath, "w") as f:
153-
f.write(content)
154-
155224
if not os.path.exists(subgraph_path):
156225
os.makedirs(subgraph_path, exist_ok=True)
157226
write_to_file(f"{subgraph_path}/model.py", sample.model)

graph_net/paddle/run_model.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1+
import os
12
import sys
23
import json
34
import base64
45
import argparse
56
import importlib.util
67
from typing import Type
78

9+
os.environ["FLAGS_logging_pir_py_code_dir"] = "/tmp/dump"
10+
811
import paddle
912
from graph_net.paddle import utils
1013

@@ -23,8 +26,11 @@ def get_input_dict(model_path):
2326
params = inputs_params["weight_info"]
2427
inputs = inputs_params["input_info"]
2528

26-
params.update(inputs)
27-
state_dict = {k: utils.replay_tensor(v) for k, v in params.items()}
29+
state_dict = {}
30+
for k, v in params.items():
31+
state_dict[k] = paddle.nn.parameter.Parameter(utils.replay_tensor(v), name=k)
32+
for k, v in inputs.items():
33+
state_dict[k] = utils.replay_tensor(v)
2834
return state_dict
2935

3036

@@ -58,9 +64,8 @@ def main(args):
5864
model = model_class()
5965
print(f"{model_path=}")
6066

61-
model = _get_decorator(args)(model)
6267
input_dict = get_input_dict(args.model_path)
63-
68+
model = _get_decorator(args)(model)
6469
model(**input_dict)
6570

6671

0 commit comments

Comments
 (0)