Skip to content

Commit 9cbb599

Browse files
committed
fix pytorch object detector loss gradient bug
Signed-off-by: Farhan Ahmed <[email protected]>
1 parent 3585836 commit 9cbb599

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

art/estimators/object_detection/pytorch_object_detector.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,10 @@ def _preprocess_and_convert_inputs(
219219

220220
# Set gradients
221221
if not no_grad:
222-
x_tensor.requires_grad = True
222+
if x_tensor.is_leaf:
223+
x_tensor.requires_grad = True
224+
else:
225+
x_tensor.retain_grad()
223226

224227
# Apply framework-specific preprocessing
225228
x_preprocessed, y_preprocessed = self._apply_preprocessing(x=x_tensor, y=y_tensor, fit=fit, no_grad=no_grad)

0 commit comments

Comments
 (0)