Skip to content

Commit b5404c2

Browse files
committed
fix mixed precision eval when use COCOCallback
1 parent 955113b commit b5404c2

File tree

3 files changed

+3
-6
lines changed

3 files changed

+3
-6
lines changed

efficientdet/keras/efficientdet_keras.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -940,10 +940,6 @@ def _postprocess(self, cls_outputs, box_outputs, scales, mode='global'):
940940
if not mode:
941941
return cls_outputs, box_outputs
942942

943-
# TODO(tanmingxing): remove this cast once FP16 works postprocessing.
944-
cls_outputs = [tf.cast(i, tf.float32) for i in cls_outputs]
945-
box_outputs = [tf.cast(i, tf.float32) for i in box_outputs]
946-
947943
if mode == 'global':
948944
return postprocess.postprocess_global(self.config.as_dict(), cls_outputs,
949945
box_outputs, scales)

efficientdet/keras/postprocess.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,9 @@
2828

2929
def to_list(inputs):
3030
if isinstance(inputs, dict):
31-
return [inputs[k] for k in sorted(inputs.keys())]
31+
return [tf.cast(inputs[k], tf.float32) for k in sorted(inputs.keys())]
3232
if isinstance(inputs, list):
33-
return inputs
33+
return [tf.cast(i, tf.float32) for i in inputs]
3434
raise ValueError('Unrecognized inputs : {}'.format(inputs))
3535

3636

efficientdet/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ PyYAML>=5.1
77
six>=1.15.0
88
tensorflow>=2.4.0
99
tensorflow-addons>=0.12
10+
tensorflow-hub>=0.11
1011
neural-structured-learning>=1.3.1
1112
tensorflow-model-optimization>=0.5
1213
Cython>=0.29.13

0 commit comments

Comments
 (0)