Skip to content

Commit 3da17e3

Browse files
author
Beat Buesser
committed
Update get_activations
Signed-off-by: Beat Buesser <[email protected]>
1 parent f0bf961 commit 3da17e3

File tree

3 files changed

+76
-65
lines changed

3 files changed

+76
-65
lines changed

art/attacks/poisoning/bullseye_polytope_attack.py

Lines changed: 66 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,7 @@ class BullseyePolytopeAttackPyTorch(PoisoningAttackWhiteBox):
5252
attack_params = PoisoningAttackWhiteBox.attack_params + [
5353
"target",
5454
"feature_layer",
55-
"opt"
56-
"max_iter",
55+
"opt" "max_iter",
5756
"learning_rate",
5857
"momentum",
5958
"decay_iter",
@@ -68,21 +67,21 @@ class BullseyePolytopeAttackPyTorch(PoisoningAttackWhiteBox):
6867
_estimator_requirements = (BaseEstimator, NeuralNetworkMixin, ClassifierMixin, PyTorchClassifier)
6968

7069
def __init__(
71-
self,
72-
classifier: Union["CLASSIFIER_NEURALNETWORK_TYPE", List["CLASSIFIER_NEURALNETWORK_TYPE"]],
73-
target: np.ndarray,
74-
feature_layer: Union[Union[str, int], List[Union[str, int]]],
75-
opt: str = 'adam',
76-
max_iter: int = 4000,
77-
learning_rate: float = 4e-2,
78-
momentum: float = 0.9,
79-
decay_iter: Union[int, List[int]] = 10000,
80-
decay_coeff: float = 0.5,
81-
epsilon: float = 0.1,
82-
dropout: int = 0.3,
83-
net_repeat: int = 1,
84-
endtoend: bool = True,
85-
verbose: bool = True,
70+
self,
71+
classifier: Union["CLASSIFIER_NEURALNETWORK_TYPE", List["CLASSIFIER_NEURALNETWORK_TYPE"]],
72+
target: np.ndarray,
73+
feature_layer: Union[Union[str, int], List[Union[str, int]]],
74+
opt: str = "adam",
75+
max_iter: int = 4000,
76+
learning_rate: float = 4e-2,
77+
momentum: float = 0.9,
78+
decay_iter: Union[int, List[int]] = 10000,
79+
decay_coeff: float = 0.5,
80+
epsilon: float = 0.1,
81+
dropout: int = 0.3,
82+
net_repeat: int = 1,
83+
endtoend: bool = True,
84+
verbose: bool = True,
8685
):
8786
"""
8887
Initialize a Bullseye Polytope clean-label poisoning attack
@@ -105,8 +104,9 @@ def __init__(
105104
:param endtoend: True for end-to-end training. False for transfer learning.
106105
:param verbose: Show progress bars.
107106
"""
108-
self.subsistute_networks: List["CLASSIFIER_NEURALNETWORK_TYPE"] = \
109-
[classifier] if not isinstance(classifier, list) else classifier
107+
self.subsistute_networks: List["CLASSIFIER_NEURALNETWORK_TYPE"] = [classifier] if not isinstance(
108+
classifier, list
109+
) else classifier
110110

111111
super().__init__(classifier=self.subsistute_networks[0]) # type: ignore
112112
self.target = target
@@ -124,12 +124,7 @@ def __init__(
124124
self.verbose = verbose
125125
self._check_params()
126126

127-
def poison(
128-
self,
129-
x: np.ndarray,
130-
y: Optional[np.ndarray] = None,
131-
**kwargs
132-
) -> Tuple[np.ndarray, np.ndarray]:
127+
def poison(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> Tuple[np.ndarray, np.ndarray]:
133128
"""
134129
Iteratively finds optimal attack points starting at values at x
135130
@@ -156,10 +151,10 @@ def forward(self):
156151
poison_batch = PoisonBatch([torch.from_numpy(np.copy(sample)).to(self.estimator.device) for sample in x])
157152
opt_method = self.opt.lower()
158153

159-
if opt_method == 'sgd':
154+
if opt_method == "sgd":
160155
logger.info("Using SGD to craft poison samples")
161156
optimizer = torch.optim.SGD(poison_batch.parameters(), lr=self.learning_rate, momentum=self.momentum)
162-
elif opt_method == 'adam':
157+
elif opt_method == "adam":
163158
logger.info("Using Adam to craft poison samples")
164159
optimizer = torch.optim.Adam(poison_batch.parameters(), lr=self.learning_rate, betas=(self.momentum, 0.999))
165160

@@ -178,11 +173,13 @@ def forward(self):
178173
for n, net in enumerate(self.subsistute_networks):
179174
# End to end training
180175
if self.endtoend:
181-
block_feats = [feat.detach()
182-
for feat in net.get_activations(x, layer=self.feature_layer, framework=True)]
176+
block_feats = [
177+
feat.detach() for feat in net.get_activations(x, layer=self.feature_layer, framework=True)
178+
]
183179
target_feat_list.append(block_feats)
184-
s_coeff = [torch.ones(n_poisons, 1).to(self.estimator.device) / n_poisons for _ in
185-
range(len(block_feats))]
180+
s_coeff = [
181+
torch.ones(n_poisons, 1).to(self.estimator.device) / n_poisons for _ in range(len(block_feats))
182+
]
186183
else:
187184
target_feat_list.append(net.get_activations(x, layer=self.feature_layer, framework=True).detach())
188185
s_coeff = torch.ones(n_poisons, 1).to(self.estimator.device) / n_poisons
@@ -192,23 +189,31 @@ def forward(self):
192189
for ite in trange(self.max_iter):
193190
if ite % self.decay_iter == 0 and ite != 0:
194191
for param_group in optimizer.param_groups:
195-
param_group['lr'] *= self.decay_coeff
196-
print("%s Iteration %d, Adjusted lr to %.2e" % (time.strftime("%Y-%m-%d %H:%M:%S"), ite,
197-
self.learning_rate))
192+
param_group["lr"] *= self.decay_coeff
193+
print(
194+
"%s Iteration %d, Adjusted lr to %.2e"
195+
% (time.strftime("%Y-%m-%d %H:%M:%S"), ite, self.learning_rate)
196+
)
198197

199198
poison_batch.zero_grad()
200-
total_loss = loss_from_center(self.subsistute_networks, target_feat_list, poison_batch, self.net_repeat,
201-
self.endtoend, self.feature_layer)
199+
total_loss = loss_from_center(
200+
self.subsistute_networks,
201+
target_feat_list,
202+
poison_batch,
203+
self.net_repeat,
204+
self.endtoend,
205+
self.feature_layer,
206+
)
202207
total_loss.backward()
203208
optimizer.step()
204209

205210
# clip the perturbations into the range
206-
perturb_range01 = torch.clamp((poison_batch.poison.data - base_tensor_batch),
207-
-self.epsilon,
208-
self.epsilon)
209-
perturbed_range01 = torch.clamp(base_range01_batch.data + perturb_range01.data,
210-
self.estimator.clip_values[0],
211-
self.estimator.clip_values[1])
211+
perturb_range01 = torch.clamp((poison_batch.poison.data - base_tensor_batch), -self.epsilon, self.epsilon)
212+
perturbed_range01 = torch.clamp(
213+
base_range01_batch.data + perturb_range01.data,
214+
self.estimator.clip_values[0],
215+
self.estimator.clip_values[1],
216+
)
212217
poison_batch.poison.data = perturbed_range01
213218

214219
if y is None:
@@ -226,7 +231,7 @@ def _check_params(self) -> None:
226231
if not isinstance(self.feature_layer, (str, int)):
227232
raise TypeError("Feature layer should be a string or int")
228233

229-
if self.opt.lower() not in ['adam', 'sgd']:
234+
if self.opt.lower() not in ["adam", "sgd"]:
230235
raise ValueError("Optimizer must be 'adam' or 'sgd'")
231236

232237
if 1 < self.momentum < 0:
@@ -244,8 +249,7 @@ def _check_params(self) -> None:
244249
if self.net_repeat < 1:
245250
raise ValueError("net_repeat must be at least 1")
246251

247-
valid_layer = 0 <= self.feature_layer < len(self.estimator.layer_names)
248-
if not valid_layer:
252+
if not 0 <= self.feature_layer < len(self.estimator.layer_names):
249253
raise ValueError("Invalid feature layer")
250254

251255
if 1 < self.decay_coeff < 0:
@@ -256,29 +260,33 @@ def get_poison_tuples(poison_batch, poison_label):
256260
"""
257261
Includes the labels
258262
"""
259-
poison = [poison_batch.poison.data[num_p].unsqueeze(0).detach().cpu().numpy()
260-
for num_p in range(poison_batch.poison.size(0))]
263+
poison = [
264+
poison_batch.poison.data[num_p].unsqueeze(0).detach().cpu().numpy()
265+
for num_p in range(poison_batch.poison.size(0))
266+
]
261267
return np.vstack(poison), poison_label
262268

263269

264-
def loss_from_center(subs_net_list, target_feat_list, poison_batch, net_repeat, end2end, feature_layer) -> \
265-
"torch.Tensor":
270+
def loss_from_center(
271+
subs_net_list, target_feat_list, poison_batch, net_repeat, end2end, feature_layer
272+
) -> "torch.Tensor":
266273
import torch
274+
267275
if end2end:
268276
loss = 0
269277
for net, center_feats in zip(subs_net_list, target_feat_list):
270278
if net_repeat > 1:
271-
poisons_feats_repeats = [net.get_activations(poison_batch(), layer=feature_layer, framework=True,
272-
input_tensor=True)
273-
for _ in range(net_repeat)]
279+
poisons_feats_repeats = [
280+
net.get_activations(poison_batch(), layer=feature_layer, framework=True) for _ in range(net_repeat)
281+
]
274282
BLOCK_NUM = len(poisons_feats_repeats[0])
275283
poisons_feats = []
276284
for block_idx in range(BLOCK_NUM):
277285
poisons_feats.append(
278-
sum([poisons_feat_r[block_idx] for poisons_feat_r in poisons_feats_repeats]) / net_repeat)
286+
sum([poisons_feat_r[block_idx] for poisons_feat_r in poisons_feats_repeats]) / net_repeat
287+
)
279288
elif net_repeat == 1:
280-
poisons_feats = net.get_activations(poison_batch(), layer=feature_layer, framework=True,
281-
input_tensor=True)
289+
poisons_feats = net.get_activations(poison_batch(), layer=feature_layer, framework=True)
282290
else:
283291
assert False, "net_repeat set to {}".format(net_repeat)
284292

@@ -295,8 +303,9 @@ def loss_from_center(subs_net_list, target_feat_list, poison_batch, net_repeat,
295303
else:
296304
loss = 0
297305
for net, center in zip(subs_net_list, target_feat_list):
298-
poisons = [net.get_activations(poison_batch(), layer=feature_layer, framework=True, input_tensor=True)
299-
for _ in range(net_repeat)]
306+
poisons = [
307+
net.get_activations(poison_batch(), layer=feature_layer, framework=True) for _ in range(net_repeat)
308+
]
300309
poisons = sum(poisons) / len(poisons)
301310

302311
diff = torch.mean(poisons, dim=0) - center

art/estimators/classification/classifier.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,7 @@ def replacement_function(self, *args, **kwargs):
5757
kwargs["x"] = np.array(kwargs["x"])
5858
else:
5959
if not isinstance(args[0], np.ndarray):
60-
if "input_tensor" not in kwargs or not kwargs["input_tensor"]:
61-
lst[0] = np.array(args[0])
60+
lst[0] = np.array(args[0])
6261

6362
if "y" in kwargs:
6463
if kwargs["y"] is not None and not isinstance(kwargs["y"], np.ndarray):
@@ -75,7 +74,7 @@ def replacement_function(self, *args, **kwargs):
7574
replacement_function.__name__ = "new_" + func_name
7675
return replacement_function
7776

78-
replacement_list_no_y = ["predict", "get_activations"]
77+
replacement_list_no_y = ["predict"]
7978
replacement_list_has_y = ["fit"]
8079

8180
for item in replacement_list_no_y:

art/estimators/classification/pytorch.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -716,8 +716,11 @@ def loss_gradient(
716716
return grads
717717

718718
def get_activations(
719-
self, x: Union[np.ndarray, "torch.Tensor"], layer: Optional[Union[int, str]] = None, batch_size: int = 128,
720-
framework: bool = False, input_tensor: bool = False
719+
self,
720+
x: Union[np.ndarray, "torch.Tensor"],
721+
layer: Optional[Union[int, str]] = None,
722+
batch_size: int = 128,
723+
framework: bool = False,
721724
) -> np.ndarray:
722725
"""
723726
Return the output of the specified layer for input `x`. `layer` is specified by layer index (between 0 and
@@ -728,7 +731,6 @@ def get_activations(
728731
:param layer: Layer for computing the activations
729732
:param batch_size: Size of batches.
730733
:param framework: If true, return the intermediate tensor representation of the activation.
731-
:param input_tensor: Whether to expect the input to be a Tensor
732734
:return: The output of `layer`, where the first dimension is the batch size corresponding to `x`.
733735
"""
734736
import torch # lgtm [py/repeated-import]
@@ -751,8 +753,9 @@ def get_activations(
751753
raise TypeError("Layer must be of type str or int")
752754

753755
if framework:
754-
return self._model(torch.from_numpy(x).to(self._device))[layer_index] if not input_tensor else \
755-
self._model(x)[layer_index]
756+
if isinstance(x, torch.Tensor):
757+
return self._model(x)[layer_index]
758+
return self._model(torch.from_numpy(x).to(self._device))[layer_index]
756759

757760
# Run prediction with batch processing
758761
results = []

0 commit comments

Comments
 (0)