Skip to content

Commit 4650875

Browse files
committed
fix pytorch yolo gradients bug
Signed-off-by: Farhan Ahmed <[email protected]>
1 parent bb50392 commit 4650875

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

art/estimators/object_detection/pytorch_yolo.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,12 @@ def _get_losses(
358358
x_preprocessed = x_preprocessed.to(self.device)
359359
y_preprocessed_yolo = y_preprocessed_yolo.to(self.device)
360360

361+
# Set gradients again after inputs are moved to another device
362+
if x_preprocessed.is_leaf:
363+
x_preprocessed.requires_grad = True
364+
else:
365+
x_preprocessed.retain_grad()
366+
361367
# Calculate loss components
362368
loss_components = self._model(x_preprocessed, y_preprocessed_yolo)
363369

0 commit comments

Comments
 (0)