1111from graph_net .paddle import utils
1212
1313
# used as configuration of python -m graph_net.paddle.run_model
class RunModelDecorator:
    """Config-driven wrapper around :func:`extract`.

    Built from a raw config dict (as parsed by
    ``python -m graph_net.paddle.run_model``); calling the instance with a
    model applies the configured extractor to it.
    """

    def __init__(self, config):
        # Validate and normalize the raw dict once, at construction time.
        self.config = self.make_config(**config)

    def __call__(self, model):
        # Delegate to the module-level extract() factory with the stored config.
        return extract(**self.config)(model)

    def make_config(
        self,
        name=None,
        dynamic=False,
        input_spec=None,
        custom_extractor_path: str = None,
        custom_extractor_config: dict = None,
    ):
        """Build the keyword dict consumed by :func:`extract`.

        Raises:
            ValueError: if ``name`` is missing. (Previously an ``assert``,
                which is stripped under ``python -O``; an explicit raise keeps
                the validation in optimized runs too.)
        """
        if name is None:
            raise ValueError("'name' is required in the run_model config")
        return {
            "name": name,
            "dynamic": dynamic,
            "input_spec": input_spec,
            "extractor_config": {
                "custom_extractor_path": custom_extractor_path,
                "custom_extractor_config": custom_extractor_config,
            },
        }
40+
41+
1442class GraphExtractor :
1543 def __init__ (
1644 self ,
@@ -26,45 +54,39 @@ def __init__(
2654 self .input_spec = input_spec
2755 assert not self .dynamic , "dynamic=True is not supported now!"
2856
29- self .subgraph_counter = 0
30- self .dump_path = os .environ .get ("GRAPH_NET_PIR_DUMP_WORKSPACE" , "/tmp" )
31- self .workspace_path = (
57+ self .num_subgraphs = 0
58+ self .num_samples_of_all_subgraphs = 0
59+ self .subgraph_idx2samples = None
60+
61+ dump_path = os .environ .get ("GRAPH_NET_PIR_DUMP_WORKSPACE" , "/tmp" )
62+ self .dump_path = os .path .abspath (dump_path )
63+
64+ workspace_path = (
3265 workspace_path
3366 if workspace_path is not None
3467 else os .environ .get ("GRAPH_NET_EXTRACT_WORKSPACE" )
3568 )
69+ self .workspace_path = os .path .abspath (workspace_path )
3670 if not self .workspace_path :
3771 raise EnvironmentError (
3872 "Environment variable 'GRAPH_NET_EXTRACT_WORKSPACE' is not set."
3973 )
4074
4175 def prepare_to_extract (self , model_dump_path ):
4276 os .makedirs (model_dump_path , exist_ok = True )
43- old_flags = paddle . get_flags (
44- [
45- "FLAGS_logging_trunc_pir_py_code" ,
46- "FLAGS_logging_pir_py_code_int_tensor_element_limit" ,
47- "FLAGS_logging_pir_py_code_dir" ,
48- ]
49- )
77+ new_flags = {
78+ "FLAGS_logging_trunc_pir_py_code" : 1 ,
79+ "FLAGS_logging_pir_py_code_int_tensor_element_limit" : 64 ,
80+ "FLAGS_logging_pir_py_code_dir" : model_dump_path ,
81+ }
82+ old_flags = paddle . get_flags ( list ( new_flags . keys ()))
83+
5084 print (f"Set pir dumping path to { model_dump_path } " )
51- paddle .set_flags (
52- {
53- "FLAGS_logging_trunc_pir_py_code" : 1 ,
54- "FLAGS_logging_pir_py_code_int_tensor_element_limit" : 64 ,
55- "FLAGS_logging_pir_py_code_dir" : model_dump_path ,
56- }
57- )
85+ paddle .set_flags (new_flags )
5886 return old_flags
5987
60- def write_to_file (self , filepath , content ):
61- print (f"Write to { filepath } " )
62- with open (filepath , "w" ) as f :
63- f .write (content )
64-
65- def __call__ (self , ** input_dict ):
66- # 1. Get model dump path
67- model_dump_path = os .path .join (self .dump_path , self .name )
88+ def run_model_with_dump_enabled (self , model_dump_path , ** input_dict ):
89+ # Get model dump path
6890 old_flags = self .prepare_to_extract (model_dump_path )
6991
7092 if self .input_spec is None :
@@ -74,13 +96,19 @@ def __call__(self, **input_dict):
7496 if isinstance (value , paddle .Tensor )
7597 ]
7698
77- # 2. Run the model to dump pir programs
99+ # Run the static model
78100 static_model = paddle .jit .to_static (
79101 self .model , input_spec = self .input_spec , full_graph = True
80102 )
81103 static_model (** input_dict )
82104
83- # 3. Convert pir programs to graphnet samples
105+ # Restore the environment
106+ paddle .set_flags (old_flags )
107+ return static_model
108+
109+ def translate_pir_program_to_sample_codes (
110+ self , model_dump_path , split_positions = None
111+ ):
84112 ir_programs_path = os .path .join (model_dump_path , "exec_programs.py" )
85113 example_inputs_path = os .path .join (
86114 model_dump_path , "programs_example_input_tensor_meta.py"
@@ -92,43 +120,108 @@ def __call__(self, **input_dict):
92120 example_inputs_path
93121 ), f"{ example_inputs_path } is not a regular file."
94122
123+ # Arguments for graph decomposer
124+ op_example_inputs_path = (
125+ os .path .join (model_dump_path , "op_example_input_tensor_meta.py" )
126+ if split_positions
127+ else None
128+ )
95129 graphnet_samples = generate_samples (
96130 model_name = self .name ,
97131 ir_programs = ir_programs_path ,
98132 example_inputs = example_inputs_path ,
133+ op_example_inputs = op_example_inputs_path ,
134+ split_positions = split_positions ,
99135 eval_mode = True ,
100136 )
101137
102- # 4. Save to model_path
138+ self .subgraph_idx2samples = {}
139+ for sample in graphnet_samples :
140+ if sample .subgraph_idx not in self .subgraph_idx2samples .keys ():
141+ self .subgraph_idx2samples [sample .subgraph_idx ] = []
142+ self .subgraph_idx2samples [sample .subgraph_idx ].append (sample )
143+
144+ self .num_subgraphs = len (self .subgraph_idx2samples )
145+ self .num_samples_of_all_subgraphs = len (graphnet_samples )
146+ assert self .num_subgraphs > 0
147+ return self .subgraph_idx2samples
148+
149+ def write_sample_to_file (self , subgraph_path , sample ):
150+ def write_to_file (filepath , content ):
151+ print (f"Write to { filepath } " )
152+ with open (filepath , "w" ) as f :
153+ f .write (content )
154+
155+ if not os .path .exists (subgraph_path ):
156+ os .makedirs (subgraph_path , exist_ok = True )
157+ write_to_file (f"{ subgraph_path } /model.py" , sample .model )
158+ write_to_file (f"{ subgraph_path } /weight_meta.py" , sample .weight_meta )
159+ write_to_file (f"{ subgraph_path } /input_meta.py" , sample .input_meta )
160+ with open (os .path .join (subgraph_path , "graph_net.json" ), "w" ) as f :
161+ json .dump (sample .metadata , f , indent = 4 )
162+
163+ def __call__ (self , ** input_dict ):
164+ # 1. Run the model to dump pir programs
165+ model_dump_path = os .path .join (self .dump_path , self .name )
166+ static_model = self .run_model_with_dump_enabled (model_dump_path , ** input_dict )
167+
168+ # 2. Convert pir programs to graphnet samples
169+ self .translate_pir_program_to_sample_codes (
170+ model_dump_path , split_positions = None
171+ )
172+
173+ # 3. Save to model_path
103174 model_path = os .path .join (self .workspace_path , self .name )
104- self .subgraph_counter = len (graphnet_samples )
105- for i , sample in enumerate (graphnet_samples ):
106- subgraph_path = (
107- model_path
108- if self .subgraph_counter == 1
109- else os .path .join (model_path , f"subgraph_{ i } " )
110- )
111- if not os .path .exists (subgraph_path ):
112- os .makedirs (subgraph_path , exist_ok = True )
113- self .write_to_file (f"{ subgraph_path } /model.py" , sample .model )
114- self .write_to_file (f"{ subgraph_path } /weight_meta.py" , sample .weight_meta )
115- self .write_to_file (f"{ subgraph_path } /input_meta.py" , sample .input_meta )
116- with open (os .path .join (subgraph_path , "graph_net.json" ), "w" ) as f :
117- json .dump (sample .metadata , f , indent = 4 )
175+ for subgraph_idx , samples in self .subgraph_idx2samples .items ():
176+ assert len (samples ) == 1
177+ if self .num_samples_of_all_subgraphs == 1 :
178+ subgraph_path = model_path
179+ else :
180+ subgraph_path = os .path .join (model_path , f"subgraph_{ subgraph_idx } " )
181+ self .write_sample_to_file (subgraph_path , samples [0 ])
118182
119183 print (
120184 f"Graph and tensors for '{ self .name } ' extracted successfully to: { model_path } "
121185 )
122-
123- # 5. Restore the environment
124- paddle .set_flags (old_flags )
125186 return static_model
126187
127188
128- def extract (name , dynamic = False , input_spec = None ):
189+ def extract (name , dynamic = False , input_spec = None , extractor_config : dict = None ):
190+ """
191+ Extract computation graphs from PaddlePaddle nn.Layer.
192+ The extracted computation graphs will be saved into directory of env var $GRAPH_NET_EXTRACT_WORKSPACE.
193+
194+ Args:
195+ name (str): The name of the model, used as the directory name for saving.
196+ dynamic (bool): Enable dynamic shape support in paddle.jit.to_static.
197+ input_spec (list[InputSpec] | tuple[InputSpec]): InputSpec for input tensors, which includes tensor's name, shape and dtype.
198+ When dynamic is False, input_spec can be inferred automatically.
199+
200+ Returns:
201+ wrapper or decorator
202+ """
203+
204+ extractor_config = make_extractor_config (extractor_config )
205+
206+ def get_graph_extractor_maker ():
207+ custom_extractor_path = extractor_config ["custom_extractor_path" ]
208+ custom_extractor_config = extractor_config ["custom_extractor_config" ]
209+ if custom_extractor_path is None :
210+ return GraphExtractor
211+ import importlib .util as imp
212+
213+ print (f"Import graph_extractor from { custom_extractor_path } " )
214+ # import custom_extractor_path as graph_extractor
215+ spec = imp .spec_from_file_location ("graph_extractor" , custom_extractor_path )
216+ graph_extractor = imp .module_from_spec (spec )
217+ spec .loader .exec_module (graph_extractor )
218+ cls = graph_extractor .GraphExtractor
219+ return lambda * args , ** kwargs : cls (custom_extractor_config , * args , ** kwargs )
220+
129221 def wrapper (model : paddle .nn .Layer ):
130222 assert isinstance (model , paddle .nn .Layer ), f"{ type (model )= } "
131- return GraphExtractor (model , name , dynamic , input_spec )
223+ extractor = get_graph_extractor_maker ()(model , name , dynamic , input_spec )
224+ return extractor
132225
133226 def decorator (module_class ):
134227 def constructor (* args , ** kwargs ):
@@ -147,3 +240,18 @@ def decorator_or_wrapper(obj):
147240 )
148241
149242 return decorator_or_wrapper
243+
244+
def make_extractor_config(extractor_config):
    """Normalize a possibly-``None`` extractor config into canonical form."""
    kwargs = {} if extractor_config is None else extractor_config
    return make_extractor_config_impl(**kwargs)
248+
249+
def make_extractor_config_impl(
    custom_extractor_path: str = None, custom_extractor_config: dict = None
):
    """Build the canonical extractor-config dict.

    ``custom_extractor_config`` defaults to an empty dict (never ``None``)
    so downstream code can index it without a None check;
    ``custom_extractor_path`` is passed through unchanged.
    """
    if custom_extractor_config is None:
        custom_extractor_config = {}
    return {
        "custom_extractor_path": custom_extractor_path,
        "custom_extractor_config": custom_extractor_config,
    }
0 commit comments