Skip to content

Commit c79edec

Browse files
authored
Fix mirrored variable load (#895)
* Fix variable-loading bug when using a mirrored strategy
* Add custom aspect ratio support
* Fix bug
* Optimize performance
1 parent e494379 commit c79edec

File tree

2 files changed

+26
-11
lines changed

2 files changed

+26
-11
lines changed

efficientdet/keras/anchors.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -116,8 +116,11 @@ def _generate_boxes(self):
116116
stride, octave_scale, aspect, anchor_scale = config
117117
base_anchor_size_x = anchor_scale * stride[1] * 2**octave_scale
118118
base_anchor_size_y = anchor_scale * stride[0] * 2**octave_scale
119-
aspect_x = np.sqrt(aspect)
120-
aspect_y = 1.0 / aspect_x
119+
if isinstance(aspect, list):
120+
aspect_x, aspect_y = aspect
121+
else:
122+
aspect_x = np.sqrt(aspect)
123+
aspect_y = 1.0 / aspect_x
121124
anchor_size_x_2 = base_anchor_size_x * aspect_x / 2.0
122125
anchor_size_y_2 = base_anchor_size_y * aspect_y / 2.0
123126

efficientdet/utils.py

Lines changed: 21 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -111,35 +111,47 @@ def get_ckpt_var_map(ckpt_path, ckpt_scope, var_scope, skip_mismatch=None):
111111
ckpt_var_name_to_shape = reader.get_variable_to_shape_map()
112112
ckpt_var_names = set(reader.get_variable_to_shape_map().keys())
113113

114+
if tf.distribute.get_replica_context():
115+
replica_id = tf.keras.backend.get_value(
116+
tf.distribute.get_replica_context().replica_id_in_sync_group)
117+
else:
118+
replica_id = 0
119+
114120
for i, v in enumerate(model_vars):
115-
if not v.op.name.startswith(var_scope):
121+
var_op_name = v.op.name
122+
123+
if replica_id >= 1:
124+
var_op_name = ''.join(var_op_name.rsplit(f'/replica_{replica_id}', 1))
125+
126+
if not var_op_name.startswith(var_scope):
116127
logging.info('skip {} -- does not match scope {}'.format(
117-
v.op.name, var_scope))
118-
ckpt_var = ckpt_scope + v.op.name[len(var_scope):]
128+
var_op_name, var_scope))
129+
ckpt_var = ckpt_scope + var_op_name[len(var_scope):]
130+
119131
if (ckpt_var not in ckpt_var_names and
120-
v.op.name.endswith('/ExponentialMovingAverage')):
121-
ckpt_var = ckpt_scope + v.op.name[:-len('/ExponentialMovingAverage')]
132+
var_op_name.endswith('/ExponentialMovingAverage')):
133+
ckpt_var = ckpt_scope + var_op_name[:-len('/ExponentialMovingAverage')]
122134

123135
if ckpt_var not in ckpt_var_names:
124136
if 'Momentum' in ckpt_var or 'RMSProp' in ckpt_var:
125137
# Skip optimizer variables.
126138
continue
127139
if skip_mismatch:
128-
logging.info('skip {} ({}) -- not in ckpt'.format(v.op.name, ckpt_var))
140+
logging.info('skip {} ({}) -- not in ckpt'.format(var_op_name, ckpt_var))
129141
continue
130142
raise ValueError('{} is not in ckpt {}'.format(v.op, ckpt_path))
131143

132144
if v.shape != ckpt_var_name_to_shape[ckpt_var]:
133145
if skip_mismatch:
134146
logging.info('skip {} ({} vs {}) -- shape mismatch'.format(
135-
v.op.name, v.shape, ckpt_var_name_to_shape[ckpt_var]))
147+
var_op_name, v.shape, ckpt_var_name_to_shape[ckpt_var]))
136148
continue
137149
raise ValueError('shape mismatch {} ({} vs {})'.format(
138-
v.op.name, v.shape, ckpt_var_name_to_shape[ckpt_var]))
150+
var_op_name, v.shape, ckpt_var_name_to_shape[ckpt_var]))
139151

140152
if i < 5:
141153
# Log the first few elements for sanity check.
142-
logging.info('Init {} from ckpt var {}'.format(v.op.name, ckpt_var))
154+
logging.info('Init {} from ckpt var {}'.format(var_op_name, ckpt_var))
143155
var_map[ckpt_var] = v
144156

145157
return var_map

0 commit comments

Comments
 (0)