2323from __future__ import absolute_import , division , print_function , unicode_literals
2424
2525import logging
26- from typing import Dict , List , Tuple
26+ from typing import Dict , List , Tuple , Optional
2727
2828import numpy as np
2929from tqdm .auto import tqdm
@@ -54,7 +54,7 @@ class BadDetRegionalMisclassificationAttack(PoisoningAttackObjectDetector):
5454 def __init__ (
5555 self ,
5656 backdoor : PoisoningAttackBackdoor ,
57- class_source : int = 0 ,
57+ class_source : Optional [ int ] = None ,
5858 class_target : int = 1 ,
5959 percent_poison : float = 0.3 ,
6060 channels_first : bool = False ,
@@ -64,7 +64,8 @@ def __init__(
6464 Creates a new BadDet Regional Misclassification Attack
6565
6666 :param backdoor: the backdoor chosen for this attack.
67- :param class_source: The source class from which triggers were selected.
67+ :param class_source: The source class (optionally) from which triggers were selected. If no source is
68+ provided, then all classes will be poisoned.
6869 :param class_target: The target label to which the poisoned model needs to misclassify.
6970 :param percent_poison: The ratio of samples to poison in the source class, with range [0, 1].
7071 :param channels_first: Set channels first or last.
@@ -116,7 +117,7 @@ def poison( # pylint: disable=W0221
116117 target_dict = {k : v .copy () for k , v in y_i .items ()}
117118 y_poison .append (target_dict )
118119
119- if self .class_source in y_i ["labels" ]:
120+ if self .class_source is None or self . class_source in y_i ["labels" ]:
120121 source_indices .append (i )
121122
122123 # select indices of samples to poison
@@ -130,7 +131,7 @@ def poison( # pylint: disable=W0221
130131 labels = y_poison [i ]["labels" ]
131132
132133 for j , (box , label ) in enumerate (zip (boxes , labels )):
133- if label == self .class_source :
134+ if self . class_source is None or label == self .class_source :
134135 # extract the bounding box from the image
135136 x_1 , y_1 , x_2 , y_2 = box .astype (int )
136137 bounding_box = image [y_1 :y_2 , x_1 :x_2 , :]
0 commit comments