Skip to content

Commit 24ef040

Browse files
authored
fix tpu variable load (#940)
1 parent b5404c2 commit 24ef040

File tree

1 file changed

+16
-13
lines changed

1 file changed

+16
-13
lines changed

efficientdet/keras/util_keras.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -111,8 +111,8 @@ def restore_ckpt(model,
111111
if tf.io.gfile.isdir(ckpt_path_or_file):
112112
ckpt_path_or_file = tf.train.latest_checkpoint(ckpt_path_or_file)
113113

114-
if (tf.train.list_variables(ckpt_path_or_file)[0][0] ==
115-
'_CHECKPOINTABLE_OBJECT_GRAPH'):
114+
var_shape_map = tf.train.load_checkpoint(ckpt_path_or_file).get_variable_to_shape_map()
115+
if '_CHECKPOINTABLE_OBJECT_GRAPH' in var_shape_map:
116116
model.load_weights(ckpt_path_or_file)
117117
else:
118118
if ema_decay > 0:
@@ -133,20 +133,23 @@ def restore_ckpt(model,
133133
# try to load graph-based checkpoint with ema support,
134134
# else load checkpoint via keras.load_weights which doesn't support ema.
135135
for i, (key, var) in enumerate(var_dict.items()):
136-
try:
137-
var.assign(tf.train.load_variable(ckpt_path_or_file, key))
138-
if i < 10:
139-
logging.info('Init %s from %s (%s)', var.name, key, ckpt_path_or_file)
140-
except tf.errors.NotFoundError as e:
141-
if skip_mismatch:
142-
logging.warning('Not found %s in %s', key, ckpt_path_or_file)
136+
if key in var_shape_map:
137+
if var_shape_map[key] != var.shape:
138+
msg = 'Shape mismatch: %s' % key
139+
if skip_mismatch:
140+
logging.warning(msg)
141+
else:
142+
raise ValueError(msg)
143143
else:
144-
raise e
145-
except ValueError as e:
144+
var.assign(tf.train.load_variable(ckpt_path_or_file, key))
145+
if i < 10:
146+
logging.info('Init %s from %s (%s)', var.name, key, ckpt_path_or_file)
147+
else:
148+
msg = 'Not found %s in %s' % (key, ckpt_path_or_file)
146149
if skip_mismatch:
147-
logging.warning('%s: %s', key, e)
150+
logging.warning(msg)
148151
else:
149-
raise e
152+
raise KeyError(msg)
150153

151154

152155
def fp16_to_fp32_nested(input_nested):

0 commit comments

Comments
 (0)