11import os
22import json
3-
4- os .environ ["ENABLE_CINN_IN_DY2ST" ] = "0"
5- # os.environ["FLAGS_logging_trunc_pir_py_code"] = "1"
6- # os.environ["FLAGS_logging_pir_py_code_int_tensor_element_limit"] = "64"
7- os .environ ["FLAGS_logging_pir_py_code_dir" ] = "/tmp/dump"
3+ import importlib .util
84
95import paddle
10- from athena .module_op_unittests_for_graphnet import GraphnetSample , generate_samples
6+ from athena .graphnet_samples import GraphnetSample , RunGeneration
7+ from graph_net import imp_util
118from graph_net .paddle import utils
129
1310
11+ def load_class_from_file (file_path : str , class_name : str ):
12+ print (f"Load { class_name } from { file_path } " )
13+ module = imp_util .load_module (file_path , "unnamed" )
14+ model_class = getattr (module , class_name , None )
15+ return model_class
16+
17+
18+ def write_to_file (filepath , content ):
19+ print (f"Write to { filepath } " )
20+ with open (filepath , "w" ) as f :
21+ f .write (content )
22+
23+
24+ def generate_model_wrapper_class (model_dump_path , data_arg_names ):
25+ graph_module_wrapper_class_template = """
26+ import paddle
27+
28+ class GraphModuleWrapper(paddle.nn.Layer):
29+ def __init__(self, graph_module):
30+ super().__init__()
31+ self.graph_module = graph_module
32+
33+ def set_parameters(self, **kwargs):
34+ for name, value in kwargs.items():
35+ if isinstance(value, paddle.nn.parameter.Parameter):
36+ setattr(self, name, value)
37+
38+ def forward(self, ${DATA_ARG_NAMES}):
39+ param_dict = { name: param for name, param in self.named_parameters() }
40+ outputs = self.graph_module(${DATA_ARG_VALUE_PAIRS}, **param_dict)
41+ return outputs
42+ """
43+
44+ data_arg_value_pairs = [f"{ name } ={ name } " for name in data_arg_names ]
45+ graph_module_wrapper_class_code_str = graph_module_wrapper_class_template .replace (
46+ "${DATA_ARG_NAMES}" , ", " .join (data_arg_names )
47+ ).replace ("${DATA_ARG_VALUE_PAIRS}" , ", " .join (data_arg_value_pairs ))
48+ print (graph_module_wrapper_class_code_str )
49+
50+ file_path = os .path .join (model_dump_path , "graph_module_wrapper.py" )
51+ write_to_file (file_path , graph_module_wrapper_class_code_str )
52+ model_class = load_class_from_file (
53+ file_path = file_path , class_name = "GraphModuleWrapper"
54+ )
55+ return model_class
56+
57+
1458# used as configuration of python -m graph_net.paddle.run_model
1559class RunModelDecorator :
1660 def __init__ (self , config ):
@@ -89,18 +133,43 @@ def run_model_with_dump_enabled(self, model_dump_path, **input_dict):
89133 # Get model dump path
90134 old_flags = self .prepare_to_extract (model_dump_path )
91135
136+ param_dict = {
137+ k : v
138+ for k , v in input_dict .items ()
139+ if isinstance (v , paddle .nn .parameter .Parameter )
140+ }
141+ data_dict = {k : v for k , v in input_dict .items () if k not in param_dict }
142+
143+ input_spec = self .input_spec
92144 if self .input_spec is None :
93- self . input_spec = [
145+ input_spec = [
94146 paddle .static .InputSpec (value .shape , value .dtype , name = name )
95- for name , value in input_dict .items ()
147+ for name , value in data_dict .items ()
96148 if isinstance (value , paddle .Tensor )
97149 ]
150+ else :
151+ assert len (input_spec ) == len (data_dict )
152+
153+ if param_dict :
154+ model_wrapper_class = generate_model_wrapper_class (
155+ model_dump_path , data_dict .keys ()
156+ )
157+ wrapped_model = model_wrapper_class (self .model )
158+ wrapped_model .set_parameters (** param_dict )
159+ else :
160+ wrapped_model = self .model
98161
99162 # Run the static model
100163 static_model = paddle .jit .to_static (
101- self .model , input_spec = self .input_spec , full_graph = True
164+ wrapped_model ,
165+ input_spec = input_spec ,
166+ full_graph = True ,
167+ backend = None ,
102168 )
103- static_model (** input_dict )
169+ static_model .eval ()
170+ program = static_model .forward .concrete_program .main_program
171+ # print(program)
172+ static_model (** data_dict )
104173
105174 # Restore the environment
106175 paddle .set_flags (old_flags )
@@ -126,7 +195,7 @@ def translate_pir_program_to_sample_codes(
126195 if split_positions
127196 else None
128197 )
129- graphnet_samples = generate_samples (
198+ all_samples = RunGeneration (
130199 model_name = self .name ,
131200 ir_programs = ir_programs_path ,
132201 example_inputs = example_inputs_path ,
@@ -136,22 +205,17 @@ def translate_pir_program_to_sample_codes(
136205 )
137206
138207 self .subgraph_idx2samples = {}
139- for sample in graphnet_samples :
208+ for sample in all_samples :
140209 if sample .subgraph_idx not in self .subgraph_idx2samples .keys ():
141210 self .subgraph_idx2samples [sample .subgraph_idx ] = []
142211 self .subgraph_idx2samples [sample .subgraph_idx ].append (sample )
143212
144213 self .num_subgraphs = len (self .subgraph_idx2samples )
145- self .num_samples_of_all_subgraphs = len (graphnet_samples )
214+ self .num_samples_of_all_subgraphs = len (all_samples )
146215 assert self .num_subgraphs > 0
147216 return self .subgraph_idx2samples
148217
149218 def write_sample_to_file (self , subgraph_path , sample ):
150- def write_to_file (filepath , content ):
151- print (f"Write to { filepath } " )
152- with open (filepath , "w" ) as f :
153- f .write (content )
154-
155219 if not os .path .exists (subgraph_path ):
156220 os .makedirs (subgraph_path , exist_ok = True )
157221 write_to_file (f"{ subgraph_path } /model.py" , sample .model )
@@ -208,14 +272,8 @@ def get_graph_extractor_maker():
208272 custom_extractor_config = extractor_config ["custom_extractor_config" ]
209273 if custom_extractor_path is None :
210274 return GraphExtractor
211- import importlib .util as imp
212-
213- print (f"Import graph_extractor from { custom_extractor_path } " )
214- # import custom_extractor_path as graph_extractor
215- spec = imp .spec_from_file_location ("graph_extractor" , custom_extractor_path )
216- graph_extractor = imp .module_from_spec (spec )
217- spec .loader .exec_module (graph_extractor )
218- cls = graph_extractor .GraphExtractor
275+
276+ cls = load_class_from_file (custom_extractor_path , "GraphExtractor" )
219277 return lambda * args , ** kwargs : cls (custom_extractor_config , * args , ** kwargs )
220278
221279 def wrapper (model : paddle .nn .Layer ):
0 commit comments