We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 52c240a commit bb50392Copy full SHA for bb50392
art/estimators/object_detection/pytorch_object_detector.py
@@ -270,6 +270,12 @@ def _get_losses(
270
x_preprocessed = x_preprocessed.to(self.device)
271
y_preprocessed = [{k: v.to(self.device) for k, v in y_i.items()} for y_i in y_preprocessed]
272
273
+ # Set gradients again after inputs are moved to another device
274
+ if x_preprocessed.is_leaf:
275
+ x_preprocessed.requires_grad = True
276
+ else:
277
+ x_preprocessed.retain_grad()
278
+
279
loss_components = self._model(x_preprocessed, y_preprocessed)
280
281
return loss_components, x_preprocessed
0 commit comments