Skip to content

Commit bbeae7d

Browse files
authored
Merge branch 'dev_1.14.1' into ibp_loss_weighting_fix
2 parents e87ed91 + 6a8d734 commit bbeae7d

File tree

6 files changed

+673
-662
lines changed

6 files changed

+673
-662
lines changed

art/attacks/poisoning/bad_det/bad_det_rma.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from __future__ import absolute_import, division, print_function, unicode_literals
2424

2525
import logging
26-
from typing import Dict, List, Tuple
26+
from typing import Dict, List, Tuple, Optional
2727

2828
import numpy as np
2929
from tqdm.auto import tqdm
@@ -54,7 +54,7 @@ class BadDetRegionalMisclassificationAttack(PoisoningAttackObjectDetector):
5454
def __init__(
5555
self,
5656
backdoor: PoisoningAttackBackdoor,
57-
class_source: int = 0,
57+
class_source: Optional[int] = None,
5858
class_target: int = 1,
5959
percent_poison: float = 0.3,
6060
channels_first: bool = False,
@@ -64,7 +64,8 @@ def __init__(
6464
Creates a new BadDet Regional Misclassification Attack
6565
6666
:param backdoor: the backdoor chosen for this attack.
67-
:param class_source: The source class from which triggers were selected.
67+
:param class_source: The source class (optionally) from which triggers were selected. If no source is
68+
provided, then all classes will be poisoned.
6869
:param class_target: The target label to which the poisoned model needs to misclassify.
6970
:param percent_poison: The ratio of samples to poison in the source class, with range [0, 1].
7071
:param channels_first: Set channels first or last.
@@ -116,7 +117,7 @@ def poison( # pylint: disable=W0221
116117
target_dict = {k: v.copy() for k, v in y_i.items()}
117118
y_poison.append(target_dict)
118119

119-
if self.class_source in y_i["labels"]:
120+
if self.class_source is None or self.class_source in y_i["labels"]:
120121
source_indices.append(i)
121122

122123
# select indices of samples to poison
@@ -130,7 +131,7 @@ def poison( # pylint: disable=W0221
130131
labels = y_poison[i]["labels"]
131132

132133
for j, (box, label) in enumerate(zip(boxes, labels)):
133-
if label == self.class_source:
134+
if self.class_source is None or label == self.class_source:
134135
# extract the bounding box from the image
135136
x_1, y_1, x_2, y_2 = box.astype(int)
136137
bounding_box = image[y_1:y_2, x_1:x_2, :]

0 commit comments

Comments
 (0)