@@ -483,11 +483,11 @@ def save_checkpoint(executor,
483
483
:param main_program
484
484
:param max_num_checkpoints
485
485
"""
486
- if checkpoint_dir is None :
487
- raise ValueError ("The values of 'checkpoint_dir' should not be None" )
486
+ if checkpoint_dir . strip () is None :
487
+ raise ValueError ("'checkpoint_dir' should not be None" )
488
488
489
- if trainer_args and not isinstance ( trainer_args , dict ) :
490
- raise TypeError ( "The type of ' trainer_args' should be dict" )
489
+ if trainer_args :
490
+ assert isinstance ( trainer_args , dict )
491
491
492
492
if not os .path .isdir (checkpoint_dir ):
493
493
os .makedirs (checkpoint_dir )
@@ -514,11 +514,11 @@ def load_checkpoint(executor, checkpoint_dir, serial, main_program):
514
514
:param main_program
515
515
"""
516
516
517
- if checkpoint_dir is None :
518
- raise ValueError ("The values of 'checkpoint_dir' should not be None" )
517
+ if checkpoint_dir . strip () is None :
518
+ raise ValueError ("'checkpoint_dir' should not be None" )
519
519
520
520
if serial is None or serial < 0 :
521
- raise ValueError ("The values of 'serial' should not be None or <0 " )
521
+ raise ValueError ("'serial' should not be None or <0 " )
522
522
523
523
if main_program is None :
524
524
raise ValueError ('main_program should not be None.' )
@@ -536,8 +536,8 @@ def clean_checkpoint(checkpoint_dir, delete_dir=False):
536
536
:param delete_dir
537
537
"""
538
538
539
- if checkpoint_dir is None :
540
- raise ValueError ("The values of 'checkpoint_dir' should not be None" )
539
+ if checkpoint_dir . strip () is None :
540
+ raise ValueError ("'checkpoint_dir' should not be None" )
541
541
_lru_delete (checkpoint_dir , max_num_checkpoints = 0 )
542
542
543
543
if delete_dir and not os .listdir (checkpoint_dir ):
@@ -590,8 +590,8 @@ def save_persist_vars_without_grad(executor, dirname, program):
590
590
591
591
592
592
def save_trainer_args (dirname , trainer_id , trainer_args ):
593
- if not isinstance (trainer_args , dict ):
594
- raise TypeError ( "The type of 'trainer_args' should be dict" )
593
+ assert isinstance (trainer_args , dict )
594
+
595
595
cur_dir = _get_trainer_dir (dirname , trainer_id )
596
596
597
597
for name , value in trainer_args .iteritems ():
@@ -602,12 +602,11 @@ def save_trainer_args(dirname, trainer_id, trainer_args):
602
602
603
603
604
604
def load_trainer_args (checkpoint_dir , serial , trainer_id , trainer_args ):
605
+ assert isinstance (trainer_args , list )
606
+
605
607
cur_dir = _get_serial_dir (checkpoint_dir , serial )
606
608
cur_dir = _get_trainer_dir (cur_dir , trainer_id )
607
609
608
- if not isinstance (trainer_args , list ):
609
- raise TypeError ("The type of 'trainer_args' should be list" )
610
-
611
610
ret_values = []
612
611
613
612
for arg in trainer_args :
0 commit comments