File tree Expand file tree Collapse file tree 1 file changed +3
-2
lines changed Expand file tree Collapse file tree 1 file changed +3
-2
lines changed Original file line number Diff line number Diff line change @@ -111,7 +111,8 @@ def restore_ckpt(model,
111111 if tf .io .gfile .isdir (ckpt_path_or_file ):
112112 ckpt_path_or_file = tf .train .latest_checkpoint (ckpt_path_or_file )
113113
114- var_shape_map = tf .train .load_checkpoint (ckpt_path_or_file ).get_variable_to_shape_map ()
114+ reader = tf .train .load_checkpoint (ckpt_path_or_file )
115+ var_shape_map = reader .get_variable_to_shape_map ()
115116 if '_CHECKPOINTABLE_OBJECT_GRAPH' in var_shape_map :
116117 model .load_weights (ckpt_path_or_file )
117118 else :
@@ -141,7 +142,7 @@ def restore_ckpt(model,
141142 else :
142143 raise ValueError (msg )
143144 else :
144- var .assign (tf . train . load_variable ( ckpt_path_or_file , key ))
145+ var .assign (reader . get_tensor ( key ), read_value = False )
145146 if i < 10 :
146147 logging .info ('Init %s from %s (%s)' , var .name , key , ckpt_path_or_file )
147148 else :
You can’t perform that action at this time.
0 commit comments