1111from graph_net .paddle import utils
1212
1313
14+ # used as configuration of python -m graph_net.paddle.run_model
15+ class RunModelDecorator :
16+ def __init__ (self , config ):
17+ self .config = self .make_config (** config )
18+
19+ def __call__ (self , model ):
20+ return extract (** self .config )(model )
21+
22+ def make_config (
23+ self ,
24+ name = None ,
25+ dynamic = True ,
26+ input_spec = None ,
27+ custom_extractor_path : str = None ,
28+ custom_extractor_config : dict = None ,
29+ ):
30+ assert name is not None
31+ return {
32+ "name" : name ,
33+ "dynamic" : dynamic ,
34+ "input_spec" : input_spec ,
35+ "extractor_config" : {
36+ "custom_extractor_path" : custom_extractor_path ,
37+ "custom_extractor_config" : custom_extractor_config ,
38+ },
39+ }
40+
41+
1442class GraphExtractor :
1543 def __init__ (
1644 self ,
@@ -125,10 +153,42 @@ def __call__(self, **input_dict):
125153 return static_model
126154
127155
128- def extract (name , dynamic = False , input_spec = None ):
156+ def extract (name , dynamic = False , input_spec = None , extractor_config : dict = None ):
157+ """
158+ Extract computation graphs from PaddlePaddle nn.Layer.
159+ The extracted computation graphs will be saved into directory of env var $GRAPH_NET_EXTRACT_WORKSPACE.
160+
161+ Args:
162+ name (str): The name of the model, used as the directory name for saving.
163+ dynamic (bool): Enable dynamic shape support in paddle.jit.to_static.
164+ input_spec (list[InputSpec] | tuple[InputSpec]): InputSpec for input tensors, which includes tensor's name, shape and dtype.
165+ When dynamic is False, input_spec can be inferred automatically.
166+
167+ Returns:
168+ wrapper or decorator
169+ """
170+
171+ extractor_config = make_extractor_config (extractor_config )
172+
173+ def get_graph_extractor_maker ():
174+ custom_extractor_path = extractor_config ["custom_extractor_path" ]
175+ custom_extractor_config = extractor_config ["custom_extractor_config" ]
176+ if custom_extractor_path is None :
177+ return GraphExtractor
178+ import importlib .util as imp
179+
180+ print (f"Import graph_extractor from { custom_extractor_path } " )
181+ # import custom_extractor_path as graph_extractor
182+ spec = imp .spec_from_file_location ("graph_extractor" , custom_extractor_path )
183+ graph_extractor = imp .module_from_spec (spec )
184+ spec .loader .exec_module (graph_extractor )
185+ cls = graph_extractor .GraphExtractor
186+ return lambda * args , ** kwargs : cls (custom_extractor_config , * args , ** kwargs )
187+
129188 def wrapper (model : paddle .nn .Layer ):
130189 assert isinstance (model , paddle .nn .Layer ), f"{ type (model )= } "
131- return GraphExtractor (model , name , dynamic , input_spec )
190+ extractor = get_graph_extractor_maker ()(model , name , dynamic , input_spec )
191+ return extractor
132192
133193 def decorator (module_class ):
134194 def constructor (* args , ** kwargs ):
@@ -147,3 +207,18 @@ def decorator_or_wrapper(obj):
147207 )
148208
149209 return decorator_or_wrapper
210+
211+
212+ def make_extractor_config (extractor_config ):
213+ kwargs = extractor_config if extractor_config is not None else {}
214+ return make_extractor_config_impl (** kwargs )
215+
216+
217+ def make_extractor_config_impl (
218+ custom_extractor_path : str = None , custom_extractor_config : dict = None
219+ ):
220+ config = custom_extractor_config if custom_extractor_config is not None else {}
221+ return {
222+ "custom_extractor_path" : custom_extractor_path ,
223+ "custom_extractor_config" : config ,
224+ }
0 commit comments