Skip to content

Commit 3b5e3f9

Browse files
committed
update checkpoint unittest
1 parent 951fa74 commit 3b5e3f9

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

python/paddle/fluid/tests/unittests/test_checkpoint.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import paddle.fluid as fluid
1616
import unittest
17+
import os
1718

1819

1920
class TestCheckpoint(unittest.TestCase):
@@ -35,15 +36,16 @@ def test_checkpoint(self):
3536
trainer_args = ["epoch_id", "step_id"]
3637
epoch_id, step_id = fluid.io.load_trainer_args(
3738
self.dirname, serial, self.trainer_id, trainer_args)
38-
self.assertEqual(self.step_id, step_id)
39-
self.assertEqual(self.epoch_id, epoch_id)
39+
self.assertEqual(self.step_id, int(step_id))
40+
self.assertEqual(self.epoch_id, int(epoch_id))
4041

4142
program = fluid.Program()
4243
with fluid.program_guard(program):
4344
exe = fluid.Executor(self.place)
4445
fluid.io.load_checkpoint(exe, self.dirname, serial, program)
4546

4647
fluid.io.clean_checkpoint(self.dirname, delete_dir=True)
48+
self.assertFalse(os.path.isdir(self.dirname))
4749

4850
def save_checkpoint(self):
4951
config = fluid.CheckpointConfig(self.dirname, self.max_num_checkpoints,

0 commit comments

Comments
 (0)