Skip to content

Commit f28f41d

Browse files
committed
update io.py annotations and codes
1 parent 6db240d commit f28f41d

File tree

1 file changed

+13
-14
lines changed

1 file changed

+13
-14
lines changed

python/paddle/fluid/io.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -483,11 +483,11 @@ def save_checkpoint(executor,
483483
:param main_program
484484
:param max_num_checkpoints
485485
"""
486-
if checkpoint_dir is None:
487-
raise ValueError("The values of 'checkpoint_dir' should not be None")
486+
if checkpoint_dir.strip() is None:
487+
raise ValueError("'checkpoint_dir' should not be None")
488488

489-
if trainer_args and not isinstance(trainer_args, dict):
490-
raise TypeError("The type of 'trainer_args' should be dict")
489+
if trainer_args:
490+
assert isinstance(trainer_args, dict)
491491

492492
if not os.path.isdir(checkpoint_dir):
493493
os.makedirs(checkpoint_dir)
@@ -514,11 +514,11 @@ def load_checkpoint(executor, checkpoint_dir, serial, main_program):
514514
:param main_program
515515
"""
516516

517-
if checkpoint_dir is None:
518-
raise ValueError("The values of 'checkpoint_dir' should not be None")
517+
if checkpoint_dir.strip() is None:
518+
raise ValueError("'checkpoint_dir' should not be None")
519519

520520
if serial is None or serial < 0:
521-
raise ValueError("The values of 'serial' should not be None or <0 ")
521+
raise ValueError("'serial' should not be None or <0 ")
522522

523523
if main_program is None:
524524
raise ValueError('main_program should not be None.')
@@ -536,8 +536,8 @@ def clean_checkpoint(checkpoint_dir, delete_dir=False):
536536
:param delete_dir
537537
"""
538538

539-
if checkpoint_dir is None:
540-
raise ValueError("The values of 'checkpoint_dir' should not be None")
539+
if checkpoint_dir.strip() is None:
540+
raise ValueError("'checkpoint_dir' should not be None")
541541
_lru_delete(checkpoint_dir, max_num_checkpoints=0)
542542

543543
if delete_dir and not os.listdir(checkpoint_dir):
@@ -590,8 +590,8 @@ def save_persist_vars_without_grad(executor, dirname, program):
590590

591591

592592
def save_trainer_args(dirname, trainer_id, trainer_args):
593-
if not isinstance(trainer_args, dict):
594-
raise TypeError("The type of 'trainer_args' should be dict")
593+
assert isinstance(trainer_args, dict)
594+
595595
cur_dir = _get_trainer_dir(dirname, trainer_id)
596596

597597
for name, value in trainer_args.iteritems():
@@ -602,12 +602,11 @@ def save_trainer_args(dirname, trainer_id, trainer_args):
602602

603603

604604
def load_trainer_args(checkpoint_dir, serial, trainer_id, trainer_args):
605+
assert isinstance(trainer_args, list)
606+
605607
cur_dir = _get_serial_dir(checkpoint_dir, serial)
606608
cur_dir = _get_trainer_dir(cur_dir, trainer_id)
607609

608-
if not isinstance(trainer_args, list):
609-
raise TypeError("The type of 'trainer_args' should be list")
610-
611610
ret_values = []
612611

613612
for arg in trainer_args:

0 commit comments

Comments
 (0)