Skip to content

Commit c25abb5

Browse files
authored
Merge pull request #1633 from Trusted-AI/development_issue_1630
Add support for multi-channel images in PyTorchObjectDetector
2 parents 66825ec + e12a3ca commit c25abb5

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

art/estimators/object_detection/python_object_detector.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def __init__(
7474
maximum values allowed for features. If floats are provided, these will be used as the range of all
7575
features. If arrays are provided, each value will be considered the bound for a feature, thus
7676
the shape of clip values needs to match the total number of features.
77-
:param channels_first: Set channels first or last.
77+
:param channels_first: [Currently unused] Set channels first or last.
7878
:param preprocessing_defences: Preprocessing defence(s) to be applied by the classifier.
7979
:param postprocessing_defences: Postprocessing defence(s) to be applied by the classifier.
8080
:param preprocessing: Tuple of the form `(subtrahend, divisor)` of floats or `np.ndarray` of values to be
@@ -214,7 +214,7 @@ def _get_losses(
214214
x_grad.requires_grad = True
215215
else:
216216
x_grad = x[i].to(self.device)
217-
if x_grad.shape[-1] in [1, 3]:
217+
if x_grad.shape[2] < x_grad.shape[0] and x_grad.shape[2] < x_grad.shape[1]:
218218
x_grad = torch.permute(x_grad, (2, 0, 1))
219219

220220
image_tensor_list_grad.append(x_grad)

art/estimators/object_detection/pytorch_faster_rcnn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def __init__(
7070
maximum values allowed for features. If floats are provided, these will be used as the range of all
7171
features. If arrays are provided, each value will be considered the bound for a feature, thus
7272
the shape of clip values needs to match the total number of features.
73-
:param channels_first: Set channels first or last.
73+
:param channels_first: [Currently unused] Set channels first or last.
7474
:param preprocessing_defences: Preprocessing defence(s) to be applied by the classifier.
7575
:param postprocessing_defences: Postprocessing defence(s) to be applied by the classifier.
7676
:param preprocessing: Tuple of the form `(subtrahend, divisor)` of floats or `np.ndarray` of values to be

0 commit comments

Comments
 (0)