Skip to content

Commit 0a01458

Browse files
authored
Merge pull request #2249 from f4str/pytorch-frcnn-loss-fix
Fix PyTorch Object Detection Estimators Missing Gradients Bug
2 parents 52c240a + 4650875 commit 0a01458

File tree

2 files changed

+12
-0
lines changed

2 files changed

+12
-0
lines changed

art/estimators/object_detection/pytorch_object_detector.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,12 @@ def _get_losses(
270270
x_preprocessed = x_preprocessed.to(self.device)
271271
y_preprocessed = [{k: v.to(self.device) for k, v in y_i.items()} for y_i in y_preprocessed]
272272

273+
# Set gradients again after inputs are moved to another device
274+
if x_preprocessed.is_leaf:
275+
x_preprocessed.requires_grad = True
276+
else:
277+
x_preprocessed.retain_grad()
278+
273279
loss_components = self._model(x_preprocessed, y_preprocessed)
274280

275281
return loss_components, x_preprocessed

art/estimators/object_detection/pytorch_yolo.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,12 @@ def _get_losses(
358358
x_preprocessed = x_preprocessed.to(self.device)
359359
y_preprocessed_yolo = y_preprocessed_yolo.to(self.device)
360360

361+
# Set gradients again after inputs are moved to another device
362+
if x_preprocessed.is_leaf:
363+
x_preprocessed.requires_grad = True
364+
else:
365+
x_preprocessed.retain_grad()
366+
361367
# Calculate loss components
362368
loss_components = self._model(x_preprocessed, y_preprocessed_yolo)
363369

0 commit comments

Comments
 (0)