1111from graph_net .paddle import utils
1212
1313
14+ # used as configuration of python -m graph_net.paddle.run_model
15+ class RunModelDecorator :
16+ def __init__ (self , config ):
17+ self .config = self .make_config (** config )
18+
19+ def __call__ (self , model ):
20+ return extract (** self .config )(model )
21+
22+ def make_config (
23+ self ,
24+ name = None ,
25+ dynamic = False ,
26+ input_spec = None ,
27+ custom_extractor_path : str = None ,
28+ custom_extractor_config : dict = None ,
29+ ):
30+ assert name is not None
31+ return {
32+ "name" : name ,
33+ "dynamic" : dynamic ,
34+ "input_spec" : input_spec ,
35+ "extractor_config" : {
36+ "custom_extractor_path" : custom_extractor_path ,
37+ "custom_extractor_config" : custom_extractor_config ,
38+ },
39+ }
40+
41+
1442class GraphExtractor :
1543 def __init__ (
1644 self ,
@@ -26,7 +54,10 @@ def __init__(
2654 self .input_spec = input_spec
2755 assert not self .dynamic , "dynamic=True is not supported now!"
2856
29- self .subgraph_counter = 0
57+ self .num_subgraphs = 0
58+ self .num_samples_of_all_subgraphs = 0
59+ self .subgraph_idx2samples = None
60+
3061 self .dump_path = os .environ .get ("GRAPH_NET_PIR_DUMP_WORKSPACE" , "/tmp" )
3162 self .workspace_path = (
3263 workspace_path
@@ -57,30 +88,23 @@ def prepare_to_extract(self, model_dump_path):
5788 )
5889 return old_flags
5990
60- def write_to_file (self , filepath , content ):
61- print (f"Write to { filepath } " )
62- with open (filepath , "w" ) as f :
63- f .write (content )
64-
65- def __call__ (self , ** input_dict ):
66- # 1. Get model dump path
67- model_dump_path = os .path .join (self .dump_path , self .name )
68- old_flags = self .prepare_to_extract (model_dump_path )
69-
91+ def run_model (self , ** input_dict ):
7092 if self .input_spec is None :
7193 self .input_spec = [
7294 paddle .static .InputSpec (value .shape , value .dtype , name = name )
7395 for name , value in input_dict .items ()
7496 if isinstance (value , paddle .Tensor )
7597 ]
7698
77- # 2. Run the model to dump pir programs
7899 static_model = paddle .jit .to_static (
79100 self .model , input_spec = self .input_spec , full_graph = True
80101 )
81102 static_model (** input_dict )
103+ return static_model
82104
83- # 3. Convert pir programs to graphnet samples
105+ def translate_pir_program_to_sample_codes (
106+ self , model_dump_path , split_positions = None
107+ ):
84108 ir_programs_path = os .path .join (model_dump_path , "exec_programs.py" )
85109 example_inputs_path = os .path .join (
86110 model_dump_path , "programs_example_input_tensor_meta.py"
@@ -92,29 +116,73 @@ def __call__(self, **input_dict):
92116 example_inputs_path
93117 ), f"{ example_inputs_path } is not a regular file."
94118
119+ # Arguments for graph decomposer
120+ op_example_inputs_path = (
121+ os .path .join (model_dump_path , "op_example_input_tensor_meta.py" )
122+ if split_positions
123+ else None
124+ )
125+ split_positions = (
126+ "," .join (map (str , split_positions ))
127+ if split_positions and isinstance (split_positions , (tuple , list ))
128+ else split_positions
129+ )
130+
95131 graphnet_samples = generate_samples (
96132 model_name = self .name ,
97133 ir_programs = ir_programs_path ,
98134 example_inputs = example_inputs_path ,
135+ op_example_inputs = op_example_inputs_path ,
136+ split_positions = split_positions ,
99137 eval_mode = True ,
100138 )
101139
140+ self .subgraph_idx2samples = {}
141+ for sample in graphnet_samples :
142+ if sample .subgraph_idx not in self .subgraph_idx2samples .keys ():
143+ self .subgraph_idx2samples [sample .subgraph_idx ] = []
144+ self .subgraph_idx2samples [sample .subgraph_idx ].append (sample )
145+
146+ self .num_subgraphs = len (self .subgraph_idx2samples )
147+ self .num_samples_of_all_subgraphs = len (graphnet_samples )
148+ return self .subgraph_idx2samples
149+
150+ def write_sample_to_file (self , subgraph_path , sample ):
151+ def write_to_file (filepath , content ):
152+ print (f"Write to { filepath } " )
153+ with open (filepath , "w" ) as f :
154+ f .write (content )
155+
156+ if not os .path .exists (subgraph_path ):
157+ os .makedirs (subgraph_path , exist_ok = True )
158+ write_to_file (f"{ subgraph_path } /model.py" , sample .model )
159+ write_to_file (f"{ subgraph_path } /weight_meta.py" , sample .weight_meta )
160+ write_to_file (f"{ subgraph_path } /input_meta.py" , sample .input_meta )
161+ with open (os .path .join (subgraph_path , "graph_net.json" ), "w" ) as f :
162+ json .dump (sample .metadata , f , indent = 4 )
163+
164+ def __call__ (self , ** input_dict ):
165+ # 1. Get model dump path
166+ model_dump_path = os .path .join (self .dump_path , self .name )
167+ old_flags = self .prepare_to_extract (model_dump_path )
168+
169+ # 2. Run the model to dump pir programs
170+ static_model = self .run_model (** input_dict )
171+
172+ # 3. Convert pir programs to graphnet samples
173+ self .translate_pir_program_to_sample_codes (
174+ model_dump_path , split_positions = None
175+ )
176+
102177 # 4. Save to model_path
103178 model_path = os .path .join (self .workspace_path , self .name )
104- self .subgraph_counter = len (graphnet_samples )
105- for i , sample in enumerate (graphnet_samples ):
106- subgraph_path = (
107- model_path
108- if self .subgraph_counter == 1
109- else os .path .join (model_path , f"subgraph_{ i } " )
110- )
111- if not os .path .exists (subgraph_path ):
112- os .makedirs (subgraph_path , exist_ok = True )
113- self .write_to_file (f"{ subgraph_path } /model.py" , sample .model )
114- self .write_to_file (f"{ subgraph_path } /weight_meta.py" , sample .weight_meta )
115- self .write_to_file (f"{ subgraph_path } /input_meta.py" , sample .input_meta )
116- with open (os .path .join (subgraph_path , "graph_net.json" ), "w" ) as f :
117- json .dump (sample .metadata , f , indent = 4 )
179+ for subgraph_idx , samples in self .subgraph_idx2samples .items ():
180+ assert len (samples ) == 1
181+ if self .num_samples_of_all_subgraphs == 1 :
182+ subgraph_path = model_path
183+ else :
184+ subgraph_path = os .path .join (model_path , f"subgraph_{ subgraph_idx } " )
185+ self .write_sample_to_file (subgraph_path , samples [0 ])
118186
119187 print (
120188 f"Graph and tensors for '{ self .name } ' extracted successfully to: { model_path } "
@@ -125,10 +193,42 @@ def __call__(self, **input_dict):
125193 return static_model
126194
127195
128- def extract (name , dynamic = False , input_spec = None ):
196+ def extract (name , dynamic = False , input_spec = None , extractor_config : dict = None ):
197+ """
198+ Extract computation graphs from PaddlePaddle nn.Layer.
199+ The extracted computation graphs will be saved into directory of env var $GRAPH_NET_EXTRACT_WORKSPACE.
200+
201+ Args:
202+ name (str): The name of the model, used as the directory name for saving.
203+ dynamic (bool): Enable dynamic shape support in paddle.jit.to_static.
204+ input_spec (list[InputSpec] | tuple[InputSpec]): InputSpec for input tensors, which includes tensor's name, shape and dtype.
205+ When dynamic is False, input_spec can be inferred automatically.
206+
207+ Returns:
208+ wrapper or decorator
209+ """
210+
211+ extractor_config = make_extractor_config (extractor_config )
212+
213+ def get_graph_extractor_maker ():
214+ custom_extractor_path = extractor_config ["custom_extractor_path" ]
215+ custom_extractor_config = extractor_config ["custom_extractor_config" ]
216+ if custom_extractor_path is None :
217+ return GraphExtractor
218+ import importlib .util as imp
219+
220+ print (f"Import graph_extractor from { custom_extractor_path } " )
221+ # import custom_extractor_path as graph_extractor
222+ spec = imp .spec_from_file_location ("graph_extractor" , custom_extractor_path )
223+ graph_extractor = imp .module_from_spec (spec )
224+ spec .loader .exec_module (graph_extractor )
225+ cls = graph_extractor .GraphExtractor
226+ return lambda * args , ** kwargs : cls (custom_extractor_config , * args , ** kwargs )
227+
129228 def wrapper (model : paddle .nn .Layer ):
130229 assert isinstance (model , paddle .nn .Layer ), f"{ type (model )= } "
131- return GraphExtractor (model , name , dynamic , input_spec )
230+ extractor = get_graph_extractor_maker ()(model , name , dynamic , input_spec )
231+ return extractor
132232
133233 def decorator (module_class ):
134234 def constructor (* args , ** kwargs ):
@@ -147,3 +247,18 @@ def decorator_or_wrapper(obj):
147247 )
148248
149249 return decorator_or_wrapper
250+
251+
252+ def make_extractor_config (extractor_config ):
253+ kwargs = extractor_config if extractor_config is not None else {}
254+ return make_extractor_config_impl (** kwargs )
255+
256+
257+ def make_extractor_config_impl (
258+ custom_extractor_path : str = None , custom_extractor_config : dict = None
259+ ):
260+ config = custom_extractor_config if custom_extractor_config is not None else {}
261+ return {
262+ "custom_extractor_path" : custom_extractor_path ,
263+ "custom_extractor_config" : config ,
264+ }
0 commit comments