Skip to content

Commit cbc8425

Browse files
committed
Support channels_first for upsampling.
Fix #1075 Note: TF has no native way to support upsampling for "channels_first" format, so here we use tf.transpose, which may (and often) be slow.
1 parent 2de72e3 commit cbc8425

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

efficientdet/keras/efficientdet_keras.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -290,10 +290,15 @@ def _pool2d(self, inputs, height, width, target_height, target_width):
290290
raise ValueError('Unsupported pooling type {}.'.format(self.pooling_type))
291291

292292
def _upsample2d(self, inputs, target_height, target_width):
293-
return tf.cast(
293+
if self.data_format == 'channels_first':
294+
inputs = tf.compat.v1.transpose(inputs, perm=[0, 2, 3, 1])
295+
outputs = tf.cast(
294296
tf.compat.v1.image.resize_nearest_neighbor(
295297
tf.cast(inputs, tf.float32), [target_height, target_width]),
296298
inputs.dtype)
299+
if self.data_format == 'channels_first':
300+
outputs = tf.compat.v1.transpose(outputs, perm=[0, 3, 1, 2])
301+
return outputs
297302

298303
def _maybe_apply_1x1(self, feat, training, num_channels):
299304
"""Apply 1x1 conv to change layer width if necessary."""

0 commit comments

Comments
 (0)