Skip to content

Commit ac32bf6

Browse files
committed
update input params type, test=develop
1 parent 09442fb commit ac32bf6

File tree

1 file changed

+11
-18
lines changed

1 file changed

+11
-18
lines changed

python/paddle/fluid/imperative/checkpoint.py

Lines changed: 11 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,12 @@
1717
import os
1818
import collections
1919
from .. import core
20-
from ..framework import Variable, Parameter, default_main_program
21-
from .layers import Layer
20+
from ..framework import Variable, default_main_program
2221

2322
__all__ = ['save_persistables', 'load_persistables']
2423

2524

26-
def save_persistables(obj, dirname, filename=None):
25+
def save_persistables(vardict, dirname, filename=None):
2726
"""
2827
This function filters out all variables in layer.parameters from the
2928
give `layer` and then trys to load these variables from the folder
@@ -35,7 +34,7 @@ def save_persistables(obj, dirname, filename=None):
3534
the file name.
3635
3736
Args:
38-
var_list(dict of Parameters|Layer): The parameters will
37+
vardict(dict of Parameters): The parameters will
3938
be saved. If it is None, nothing
4039
will be deal.
4140
dirname(str): The directory path.
@@ -69,17 +68,14 @@ def save_persistables(obj, dirname, filename=None):
6968
dy_loss, last_hidden, last_cell = ptb_model(x, y, init_hidden,
7069
init_cell)
7170
param_path = "./my_paddle_model"
72-
fluid.imperative.checkpoint.save_persistables(ptb_model.parameters(), dirname=param_path,
71+
fluid.imperative.checkpoint.save_persistables(ptb_model.state_dict(), dirname=param_path,
7372
layer=ptb_model)
7473
"""
75-
if isinstance(obj, collections.OrderedDict):
76-
_save_var_to_file(obj, dirname, filename)
77-
elif isinstance(obj, Layer):
78-
_save_var_to_file(
79-
obj.state_dict(include_sublayers=True), dirname, filename)
74+
if isinstance(vardict, collections.OrderedDict):
75+
_save_var_to_file(vardict, dirname, filename)
8076

8177

82-
def load_persistables(obj, dirname, filename=None):
78+
def load_persistables(vardict, dirname, filename=None):
8379
"""
8480
This function trys to load persistable variables from the folder
8581
`dirname` or the file `filename`.
@@ -90,7 +86,7 @@ def load_persistables(obj, dirname, filename=None):
9086
the file name.
9187
9288
Args:
93-
obj(dict of Parameters|Layer): The parameters will be loaded.
89+
vardict(dict of Parameters): The parameters will be loaded.
9490
dirname(str): The directory path.
9591
filename(str|None): The file which saved all variables, this file path should be end with '.npz'. If variables were
9692
saved in differnet files, set it to None.
@@ -111,16 +107,13 @@ def load_persistables(obj, dirname, filename=None):
111107
my_layer = layer(fluid.imperative.Layer)
112108
param_path = "./my_paddle_model"
113109
filename = "model.file"
114-
param_dict = fluid.imperative.checkpoint.load_persistables(my_layer, var_list, param_path,
110+
param_dict = fluid.imperative.checkpoint.load_persistables(my_layer.state_dict(), param_path,
115111
filename=filename)
116112
param_1 = param_dict['PtbModel_0.w_1']
117113
118114
"""
119-
if isinstance(obj, collections.OrderedDict):
120-
return _load_var_from_file(obj, dirname, filename)
121-
elif isinstance(obj, Layer):
122-
return _load_var_from_file(
123-
obj.state_dict(include_sublayers=True), dirname, filename)
115+
if isinstance(vardict, collections.OrderedDict):
116+
return _load_var_from_file(vardict, dirname, filename)
124117

125118
return {}
126119

0 commit comments

Comments
 (0)