Skip to content

Commit c6df4a1

Browse files
author
Beat Buesser
committed
Fix batch-norm layers in PyTorchYolo._get_losses
Signed-off-by: Beat Buesser <[email protected]>
1 parent 89bf92f commit c6df4a1

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

art/estimators/object_detection/pytorch_yolo.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,8 @@ def _get_losses(
253253
import torch # lgtm [py/repeated-import]
254254

255255
self._model.train()
256+
self.set_batchnorm(train=False)
257+
self.set_dropout(train=False)
256258

257259
# Apply preprocessing
258260
if self.all_framework_preprocessing:

0 commit comments

Comments
 (0)