Skip to content

Commit c6da8c4

Browse files
authored
Merge pull request #2238 from f4str/pytorch-object-detector-fix
FIx `PytorchObjectDetector` Loss Gradient Bug
2 parents c7f0a4a + 37628a6 commit c6da8c4

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

art/estimators/object_detection/pytorch_object_detector.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,10 @@ def _preprocess_and_convert_inputs(
219219

220220
# Set gradients
221221
if not no_grad:
222-
x_tensor.requires_grad = True
222+
if x_tensor.is_leaf:
223+
x_tensor.requires_grad = True
224+
else:
225+
x_tensor.retain_grad()
223226

224227
# Apply framework-specific preprocessing
225228
x_preprocessed, y_preprocessed = self._apply_preprocessing(x=x_tensor, y=y_tensor, fit=fit, no_grad=no_grad)

0 commit comments

Comments
 (0)