Skip to content

Commit bb50392

Browse files
committed
fix pytorch object detector gradients bug
Signed-off-by: Farhan Ahmed <[email protected]>
1 parent 52c240a commit bb50392

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

art/estimators/object_detection/pytorch_object_detector.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,12 @@ def _get_losses(
270270
x_preprocessed = x_preprocessed.to(self.device)
271271
y_preprocessed = [{k: v.to(self.device) for k, v in y_i.items()} for y_i in y_preprocessed]
272272

273+
# Set gradients again after inputs are moved to another device
274+
if x_preprocessed.is_leaf:
275+
x_preprocessed.requires_grad = True
276+
else:
277+
x_preprocessed.retain_grad()
278+
273279
loss_components = self._model(x_preprocessed, y_preprocessed)
274280

275281
return loss_components, x_preprocessed

0 commit comments

Comments
 (0)