Skip to content

Commit 676b996

Browse files
authored
Merge pull request #1471 from Trusted-AI/development_issue_1468
Replace x with x_preprocessed in PyTorchClassifier.get_activations
2 parents 94112c5 + 437f5ad commit 676b996

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

art/estimators/classification/pytorch.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -834,7 +834,11 @@ def get_activations(
834834
self._model.eval()
835835

836836
# Apply defences
837-
x_preprocessed, _ = self._apply_preprocessing(x=x, y=None, fit=False)
837+
if framework:
838+
no_grad = False
839+
else:
840+
no_grad = True
841+
x_preprocessed, _ = self._apply_preprocessing(x=x, y=None, fit=False, no_grad=no_grad)
838842

839843
# Get index of the extracted layer
840844
if isinstance(layer, six.string_types):
@@ -849,9 +853,9 @@ def get_activations(
849853
raise TypeError("Layer must be of type str or int")
850854

851855
if framework:
852-
if isinstance(x, torch.Tensor):
853-
return self._model(x)[layer_index]
854-
return self._model(torch.from_numpy(x).to(self._device))[layer_index]
856+
if isinstance(x_preprocessed, torch.Tensor):
857+
return self._model(x_preprocessed)[layer_index]
858+
return self._model(torch.from_numpy(x_preprocessed).to(self._device))[layer_index]
855859

856860
# Run prediction with batch processing
857861
results = []

0 commit comments

Comments
 (0)