@@ -64,8 +64,8 @@ def _extract_vars(inputs, result_list, err_tag='inputs'):
             _extract_vars(var, result_list, err_tag)
         else:
             raise TypeError(
-                "The type of 'each element of {}' in fluid.dygraph.jit.TracedLayer.trace must be fluid.Variable, but received {}.".
-                format(err_tag, type(inputs)))
+                "The type of 'each element of {}' in fluid.dygraph.jit.TracedLayer.trace must be fluid.Variable, but received {}."
+                .format(err_tag, type(inputs)))
 
 
 def extract_vars(inputs, err_tag='inputs'):
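The only change in this hunk is where the line break falls around the chained `.format` call; behavior is identical. For readers skimming the diff, here is a minimal standalone sketch of the pattern this error path implements (hypothetical names, with `int` standing in for `fluid.Variable`):

# Recursively flatten accepted elements into result_list; anything else
# raises a formatted TypeError, mirroring _extract_vars above.
def _collect(inputs, result_list, err_tag='inputs'):
    if isinstance(inputs, int):
        result_list.append(inputs)
    elif isinstance(inputs, (list, tuple)):
        for var in inputs:
            _collect(var, result_list, err_tag)
    else:
        raise TypeError(
            "The type of 'each element of {}' must be int, but received {}."
            .format(err_tag, type(inputs)))

out = []
_collect([1, (2, 3)], out)  # out == [1, 2, 3]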
@@ -211,29 +211,28 @@ def decorated(python_func):
         _, python_func = unwrap_decorators(python_func)
 
         # Step 2. copy some attributes from original python function.
-        static_layer = copy_decorator_attrs(
-            original_func=python_func,
-            decorated_obj=StaticFunction(
-                function=python_func,
-                input_spec=input_spec,
-                build_strategy=build_strategy))
+        static_layer = copy_decorator_attrs(original_func=python_func,
+                                            decorated_obj=StaticFunction(
+                                                function=python_func,
+                                                input_spec=input_spec,
+                                                build_strategy=build_strategy))
 
         return static_layer
 
     build_strategy = build_strategy or BuildStrategy()
     if not isinstance(build_strategy, BuildStrategy):
         raise TypeError(
-            "Required type(build_strategy) shall be `paddle.static.BuildStrategy`, but received {}".
-            format(type(build_strategy).__name__))
+            "Required type(build_strategy) shall be `paddle.static.BuildStrategy`, but received {}"
+            .format(type(build_strategy).__name__))
 
     # for usage: `declarative(foo, ...)`
     if function is not None:
         if isinstance(function, Layer):
             if isinstance(function.forward, StaticFunction):
                 class_name = function.__class__.__name__
                 logging_utils.warn(
-                    "`{}.forward` has already been decorated somewhere. It will be redecorated to replace previous one.".
-                    format(class_name))
+                    "`{}.forward` has already been decorated somewhere. It will be redecorated to replace previous one."
+                    .format(class_name))
                 function.forward = decorated(function.forward)
                 return function
         else:
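The reflow above is behavior-preserving; for context, `decorated` builds a `StaticFunction` wrapper, and `declarative` also accepts a `Layer`, rewrapping its `forward` (with the warning shown) if it was already decorated. A hedged sketch of the two call forms this branch supports, via the public `paddle.jit.to_static` alias:

import paddle

# Decorator form: `decorated` wraps the function in a StaticFunction.
@paddle.jit.to_static
def double(x):
    return x * 2

# Direct-call form, the `function is not None` branch above; applying it
# to a Layer redecorates `forward` and logs the warning from this hunk.
def inc(x):
    return x + 1

static_inc = paddle.jit.to_static(inc)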
@@ -284,6 +283,7 @@ def func(x):
 
 
 class _SaveLoadConfig(object):
+
     def __init__(self):
         self._output_spec = None
         self._model_filename = None
@@ -371,7 +371,7 @@ def keep_name_table(self, value):
 
 
 def _parse_save_configs(configs):
-    supported_configs = ['output_spec', "with_hook"]
+    supported_configs = ['output_spec', "with_hook", "clip_extra"]
 
     # input check
     for key in configs:
@@ -384,6 +384,7 @@ def _parse_save_configs(configs):
     inner_config = _SaveLoadConfig()
     inner_config.output_spec = configs.get('output_spec', None)
     inner_config.with_hook = configs.get('with_hook', False)
+    inner_config.clip_extra = configs.get("clip_extra", False)
 
     return inner_config
 
@@ -622,6 +623,7 @@ def _remove_save_pre_hook(hook):
 
 
 def _run_save_pre_hooks(func):
+
     def wrapper(layer, path, input_spec=None, **configs):
         global _save_pre_hooks
         for hook in _save_pre_hooks:
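The added blank line is cosmetic; the decorator around it is the interesting part. `_run_save_pre_hooks` wraps the save entry point so every hook in the module-global `_save_pre_hooks` list runs first. A self-contained sketch of that pattern, with a toy `save` standing in for `paddle.jit.save` (hook signature inferred from the wrapper shown above, so treat it as an assumption):

import functools

_save_pre_hooks = []

def _run_save_pre_hooks(func):

    @functools.wraps(func)
    def wrapper(layer, path, input_spec=None, **configs):
        # Fire every registered pre-save hook before the real save runs.
        for hook in _save_pre_hooks:
            hook(layer, input_spec, configs)
        return func(layer, path, input_spec, **configs)

    return wrapper

@_run_save_pre_hooks
def save(layer, path, input_spec=None, **configs):
    print('saving {} to {}'.format(layer, path))

_save_pre_hooks.append(lambda layer, spec, cfg: print('pre-hook ran'))
save('my_layer', './out')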
@@ -775,8 +777,8 @@ def fun(inputs):
             "The paddle.jit.save doesn't work when setting ProgramTranslator.enable to False."
         )
 
-    if not (isinstance(layer, Layer) or inspect.isfunction(layer) or isinstance(
-            layer, StaticFunction)):
+    if not (isinstance(layer, Layer) or inspect.isfunction(layer)
+            or isinstance(layer, StaticFunction)):
         raise TypeError(
             "The input of paddle.jit.save should be 'Layer' or 'Function', but received input type is %s."
             % type(layer))
@@ -837,7 +839,7 @@ def fun(inputs):
     # parse configs
     configs = _parse_save_configs(configs)
     # whether outermost layer has pre/post hook, if does, we need also save
-    # these operators in program. 
+    # these operators in program.
     with_hook = configs.with_hook
 
     scope = core.Scope()
@@ -848,7 +850,9 @@ def fun(inputs):
             with_hook = True
     else:
         # layer is function
-        functions = [layer, ]
+        functions = [
+            layer,
+        ]
     for attr_func in functions:
         if isinstance(layer, Layer):
             static_func = getattr(inner_layer, attr_func, None)
@@ -862,8 +866,8 @@ def fun(inputs):
                 if inner_input_spec:
                     inner_input_spec = pack_sequence_as(input_spec,
                                                         inner_input_spec)
-                static_forward = declarative(
-                    inner_layer.forward, input_spec=inner_input_spec)
+                static_forward = declarative(inner_layer.forward,
+                                             input_spec=inner_input_spec)
                 concrete_program = static_forward.concrete_program_specify_input_spec(
                     with_hook=with_hook)
                 # the input_spec has been used in declarative, which is equal to
@@ -882,14 +886,14 @@ def fun(inputs):
                 if inner_input_spec:
                     inner_input_spec = pack_sequence_as(input_spec,
                                                         inner_input_spec)
-                static_function = declarative(
-                    attr_func, input_spec=inner_input_spec)
+                static_function = declarative(attr_func,
+                                              input_spec=inner_input_spec)
                 concrete_program = static_function.concrete_program
 
                 if static_function._class_instance is None:
                     warnings.warn(
-                        '`jit.save` will only save the `Program`, not the parameters. If you have to save the parameters, please make sure that {} is a member function of `paddle.nn.Layer` and the saved parameters are in `state_dict`'.
-                        format(layer))
+                        '`jit.save` will only save the `Program`, not the parameters. If you have to save the parameters, please make sure that {} is a member function of `paddle.nn.Layer` and the saved parameters are in `state_dict`'
+                        .format(layer))
 
     dygraph_state_dict = None
     if isinstance(inner_layer, Layer):
@@ -922,8 +926,8 @@ def fun(inputs):
                 param_or_buffer_tensor = scope.var(
                     param_or_buffer.name).get_tensor()
                 #src_tensor = param_or_buffer.value().get_tensor()
-                src_tensor = state_var_dict[param_or_buffer.name].value(
-                ).get_tensor()
+                src_tensor = state_var_dict[
+                    param_or_buffer.name].value().get_tensor()
                 param_or_buffer_tensor._share_data_with(src_tensor)
             # record var info
             if param_or_buffer.name not in extra_var_info:
@@ -986,7 +990,7 @@ def fun(inputs):
             params_filename=params_filename,
             export_for_deployment=configs._export_for_deployment,
             program_only=configs._program_only,
-            clip_extra=False)
+            clip_extra=configs.clip_extra)
 
     # NOTE(chenweihang): [ Save extra variable info ]
     # save_inference_model will lose some important variable information, including:
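This is the payoff of the `clip_extra` plumbing: the flag parsed in `_parse_save_configs` now reaches `save_inference_model` instead of a hard-coded `False`, so callers can ask for training-only op attributes to be clipped from the saved inference program. A hedged usage sketch (hypothetical `Net` layer, assumes paddle is installed):

import paddle
from paddle.static import InputSpec

class Net(paddle.nn.Layer):

    def __init__(self):
        super(Net, self).__init__()
        self.fc = paddle.nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

net = Net()
# With this patch the new kwarg is honored; before it, clip_extra was
# effectively always False regardless of what the caller passed.
paddle.jit.save(net,
                path='./inference/net',
                input_spec=[InputSpec(shape=[None, 4], dtype='float32')],
                clip_extra=True)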
@@ -1534,14 +1538,16 @@ def forward(self, input):
                    "fluid.dygraph.jit.TracedLayer.save_inference_model")
         if isinstance(feed, list):
             for f in feed:
-                check_type(f, "each element of feed", int,
-                           "fluid.dygraph.jit.TracedLayer.save_inference_model")
+                check_type(
+                    f, "each element of feed", int,
+                    "fluid.dygraph.jit.TracedLayer.save_inference_model")
         check_type(fetch, "fetch", (type(None), list),
                    "fluid.dygraph.jit.TracedLayer.save_inference_model")
         if isinstance(fetch, list):
             for f in fetch:
-                check_type(f, "each element of fetch", int,
-                           "fluid.dygraph.jit.TracedLayer.save_inference_model")
+                check_type(
+                    f, "each element of fetch", int,
+                    "fluid.dygraph.jit.TracedLayer.save_inference_model")
         clip_extra = kwargs.get('clip_extra', False)
         # path check
         file_prefix = os.path.basename(path)
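The two `check_type` calls are only re-wrapped here, but they document the contract: `feed` and `fetch`, when given, are lists of integer indices into the traced inputs and outputs, and `clip_extra` arrives through `**kwargs`. A hedged usage sketch (assumes `net` is a `paddle.nn.Layer` and `x` a sample input tensor already defined):

# Trace once, then export; indices select which traced vars to keep.
out, traced = paddle.jit.TracedLayer.trace(net, inputs=[x])
traced.save_inference_model('./traced_net',
                            feed=[0],         # keep only the first input
                            fetch=[0],        # keep only the first output
                            clip_extra=True)  # read via kwargs.get above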
@@ -1575,12 +1581,11 @@ def get_feed_fetch(all_vars, partial_vars):
         model_filename = file_prefix + INFER_MODEL_SUFFIX
         params_filename = file_prefix + INFER_PARAMS_SUFFIX
 
-        save_inference_model(
-            dirname=dirname,
-            feeded_var_names=feeded_var_names,
-            target_vars=target_vars,
-            executor=self._exe,
-            main_program=self._program.clone(),
-            model_filename=model_filename,
-            params_filename=params_filename,
-            clip_extra=clip_extra)
+        save_inference_model(dirname=dirname,
+                             feeded_var_names=feeded_var_names,
+                             target_vars=target_vars,
+                             executor=self._exe,
+                             main_program=self._program.clone(),
+                             model_filename=model_filename,
+                             params_filename=params_filename,
+                             clip_extra=clip_extra)