We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent bb50392 commit 4650875Copy full SHA for 4650875
art/estimators/object_detection/pytorch_yolo.py
@@ -358,6 +358,12 @@ def _get_losses(
358
x_preprocessed = x_preprocessed.to(self.device)
359
y_preprocessed_yolo = y_preprocessed_yolo.to(self.device)
360
361
+ # Set gradients again after inputs are moved to another device
362
+ if x_preprocessed.is_leaf:
363
+ x_preprocessed.requires_grad = True
364
+ else:
365
+ x_preprocessed.retain_grad()
366
+
367
# Calculate loss components
368
loss_components = self._model(x_preprocessed, y_preprocessed_yolo)
369
0 commit comments