Skip to content

Commit bd66eed

Browse files
authored
Trainer save load params (#10386)
* Load/save the params from the params_path * Switch to use load_persistables and save_persistables * Instaed of setup the executor to run program and scope. Pass the program to the load_persistables
1 parent 5812076 commit bd66eed

File tree

2 files changed

+17
-4
lines changed

2 files changed

+17
-4
lines changed

python/paddle/fluid/inferencer.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@
1313
# limitations under the License.
1414

1515
import core
16-
16+
import framework
17+
import executor
18+
import io
1719
__all__ = ['Inferencer', ]
1820

1921

@@ -29,6 +31,15 @@ def __init__(self, network_func, param_path=None, place=None):
2931
# 4. load params from param_path into scope
3032
self.scope = core.Scope()
3133
self.place = place
34+
self.startup_program = framework.Program()
35+
# TODO: generate the startup_program with network_func
36+
37+
exe = executor.Executor(place)
38+
exe.run(self.startup_program, scope=self.scope)
39+
40+
if param_path:
41+
# load params from param_path into scope
42+
io.load_persistables(exe, dirname=param_path)
3243

3344
def infer(self, inputs):
3445
# run self.program

python/paddle/fluid/trainer.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import executor
1919
import data_feeder
2020
import contextlib
21+
import io
2122

2223
# optimizer is same as the parameter of Trainer.__init__. Rename it to opt_module
2324
import optimizer as opt_module
@@ -93,8 +94,7 @@ def __init__(self, program_func, optimizer, param_path=None, place=None):
9394

9495
if param_path:
9596
# load params from param_path into scope
96-
# TODO(yuyang): This depends on parameters implementation.
97-
pass
97+
io.load_persistables(exe, dirname=param_path)
9898

9999
def dist_transpile_if_necessary(self, optimize_ops, params_grads):
100100
if "PADDLE_TRAINING_ROLE" not in os.environ:
@@ -172,7 +172,9 @@ def test(self, reader):
172172

173173
def save_params(self, param_path):
174174
# reference: save_persistables in io.py
175-
pass
175+
exe = executor.Executor(self.place)
176+
io.save_persistables(
177+
exe, dirname=param_path, main_program=self.startup_program)
176178

177179
@staticmethod
178180
def _check_and_get_place(place):

0 commit comments

Comments
 (0)