@@ -111,8 +111,8 @@ def restore_ckpt(model,
111111 if tf .io .gfile .isdir (ckpt_path_or_file ):
112112 ckpt_path_or_file = tf .train .latest_checkpoint (ckpt_path_or_file )
113113
114- if ( tf .train .list_variables (ckpt_path_or_file )[ 0 ][ 0 ] ==
115- '_CHECKPOINTABLE_OBJECT_GRAPH' ) :
114+ var_shape_map = tf .train .load_checkpoint (ckpt_path_or_file ). get_variable_to_shape_map ()
115+ if '_CHECKPOINTABLE_OBJECT_GRAPH' in var_shape_map :
116116 model .load_weights (ckpt_path_or_file )
117117 else :
118118 if ema_decay > 0 :
@@ -133,20 +133,23 @@ def restore_ckpt(model,
133133 # try to load graph-based checkpoint with ema support,
134134 # else load checkpoint via keras.load_weights which doesn't support ema.
135135 for i , (key , var ) in enumerate (var_dict .items ()):
136- try :
137- var . assign ( tf . train . load_variable ( ckpt_path_or_file , key ))
138- if i < 10 :
139- logging . info ( 'Init %s from %s (%s)' , var . name , key , ckpt_path_or_file )
140- except tf . errors . NotFoundError as e :
141- if skip_mismatch :
142- logging . warning ( 'Not found %s in %s' , key , ckpt_path_or_file )
136+ if key in var_shape_map :
137+ if var_shape_map [ key ] != var . shape :
138+ msg = 'Shape mismatch: %s' % key
139+ if skip_mismatch :
140+ logging . warning ( msg )
141+ else :
142+ raise ValueError ( msg )
143143 else :
144- raise e
145- except ValueError as e :
144+ var .assign (tf .train .load_variable (ckpt_path_or_file , key ))
145+ if i < 10 :
146+ logging .info ('Init %s from %s (%s)' , var .name , key , ckpt_path_or_file )
147+ else :
148+ msg = 'Not found %s in %s' % (key , ckpt_path_or_file )
146149 if skip_mismatch :
147- logging .warning ('%s: %s' , key , e )
150+ logging .warning (msg )
148151 else :
149- raise e
152+ raise KeyError ( msg )
150153
151154
152155def fp16_to_fp32_nested (input_nested ):
0 commit comments