File tree Expand file tree Collapse file tree 2 files changed +12
-0
lines changed
art/estimators/object_detection Expand file tree Collapse file tree 2 files changed +12
-0
lines changed Original file line number Diff line number Diff line change @@ -270,6 +270,12 @@ def _get_losses(
270270 x_preprocessed = x_preprocessed .to (self .device )
271271 y_preprocessed = [{k : v .to (self .device ) for k , v in y_i .items ()} for y_i in y_preprocessed ]
272272
273+ # Set gradients again after inputs are moved to another device
274+ if x_preprocessed .is_leaf :
275+ x_preprocessed .requires_grad = True
276+ else :
277+ x_preprocessed .retain_grad ()
278+
273279 loss_components = self ._model (x_preprocessed , y_preprocessed )
274280
275281 return loss_components , x_preprocessed
Original file line number Diff line number Diff line change @@ -358,6 +358,12 @@ def _get_losses(
358358 x_preprocessed = x_preprocessed .to (self .device )
359359 y_preprocessed_yolo = y_preprocessed_yolo .to (self .device )
360360
361+ # Set gradients again after inputs are moved to another device
362+ if x_preprocessed .is_leaf :
363+ x_preprocessed .requires_grad = True
364+ else :
365+ x_preprocessed .retain_grad ()
366+
361367 # Calculate loss components
362368 loss_components = self ._model (x_preprocessed , y_preprocessed_yolo )
363369
You can’t perform that action at this time.
0 commit comments