Skip to content

Commit 39d85bf

Browse files
authored
Enable load program state in imperative mode (#24998) (#25441)
* enable load_program_state run in imperative mode; test=develop * remove useless code; test=develop
1 parent 6e1d0ef commit 39d85bf

File tree

2 files changed

+7
-2
lines changed

2 files changed

+7
-2
lines changed

python/paddle/fluid/io.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,8 @@ def _load_program_scope(main=None, startup=None, scope=None):
191191
with paddle.fluid.scope_guard(scope):
192192
with paddle.fluid.program_guard(prog, startup_prog):
193193
with paddle.fluid.unique_name.guard():
194-
yield
194+
with paddle.fluid.framework._dygraph_guard(None):
195+
yield
195196

196197

197198
def _get_valid_program(main_program):

python/paddle/fluid/tests/unittests/test_static_save_load.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1153,7 +1153,6 @@ def test_ptb_rnn_cpu_float32(self):
11531153
# make sure all the paramerter or optimizer var have been set to zero
11541154
self.assertTrue(np.sum(np.abs(new_t)) == 0)
11551155

1156-
#fluid.load(test_program, "./test_1", None )
11571156
program_state = fluid.load_program_state("test_program_1")
11581157
fluid.set_program_state(main_program, program_state)
11591158

@@ -1164,6 +1163,11 @@ def test_ptb_rnn_cpu_float32(self):
11641163
base_t = base_map[var.name]
11651164
self.assertTrue(np.array_equal(new_t, base_t))
11661165

1166+
with fluid.dygraph.guard(place):
1167+
load_state = fluid.load_program_state("test_program_1")
1168+
for k, v in load_state.items():
1169+
self.assertTrue(np.array_equal(base_map[k], v))
1170+
11671171

11681172
class TestProgramStateOldSaveSingleModel(unittest.TestCase):
11691173
def test_ptb_rnn_cpu_float32(self):

0 commit comments

Comments
 (0)