 
 import abc
 import logging
-from typing import List, Optional, Tuple, TYPE_CHECKING
+from typing import Any, List, Optional, Tuple, Union, TYPE_CHECKING
 
 import numpy as np
 
 from art.exceptions import EstimatorError
 
 if TYPE_CHECKING:
-    from art.estimators.classification.classifier import Classifier
+    from art.utils import CLASSIFIER_TYPE
 
 logger = logging.getLogger(__name__)
 
@@ -90,11 +90,15 @@ class Attack(abc.ABC, metaclass=input_filter):
     """
 
     attack_params: List[str] = list()
+    _estimator_requirements: Optional[Union[Tuple[Any, ...], Tuple[()]]] = None
 
     def __init__(self, estimator):
         """
         :param estimator: An estimator.
         """
+        if self.estimator_requirements is None:
+            raise ValueError("Estimator requirements have not been defined in `_estimator_requirements`.")
+
         if not all(t in type(estimator).__mro__ for t in self.estimator_requirements):
             raise EstimatorError(self.__class__, self.estimator_requirements, estimator)
 
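With this change every attack subclass must declare which estimator types it supports: `Attack.__init__` now raises a `ValueError` when `_estimator_requirements` is left at its default of `None`, and continues to raise an `EstimatorError` when the wrapped estimator's MRO does not contain all required classes. A minimal sketch of a conforming subclass, assuming ART's public `BaseEstimator` and `ClassifierMixin` base classes; the class name `DummyEvasionAttack`, its `eps` parameter, and the stub perturbation are purely illustrative and not part of this commit:

import numpy as np

from art.attacks.attack import EvasionAttack
from art.estimators.estimator import BaseEstimator
from art.estimators.classification.classifier import ClassifierMixin


class DummyEvasionAttack(EvasionAttack):
    """Illustrative attack; only the estimator-requirement plumbing matters here."""

    attack_params = EvasionAttack.attack_params + ["eps"]
    # Leaving this at the inherited None would now raise ValueError in Attack.__init__;
    # passing an estimator that lacks any of these base classes raises EstimatorError.
    _estimator_requirements = (BaseEstimator, ClassifierMixin)

    def __init__(self, estimator, eps: float = 0.1):
        super().__init__(estimator=estimator)
        self.eps = eps

    def generate(self, x: np.ndarray, y=None, **kwargs) -> np.ndarray:
        # Stub perturbation; a real attack would use the estimator's gradients or queries.
        return np.clip(x + self.eps * np.sign(np.random.standard_normal(size=x.shape)), 0.0, 1.0)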
@@ -128,6 +132,10 @@ class EvasionAttack(Attack):
     Abstract base class for evasion attack classes.
     """
 
+    def __init__(self, **kwargs) -> None:
+        self._targeted = False
+        super().__init__(**kwargs)
+
     @abc.abstractmethod
     def generate(  # lgtm [py/inheritance/incorrect-overridden-signature]
         self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs
@@ -143,34 +151,77 @@ def generate( # lgtm [py/inheritance/incorrect-overridden-signature]
         """
         raise NotImplementedError
 
+    @property
+    def targeted(self) -> bool:
+        """
+        Return Boolean if attack is targeted. Return None if not applicable.
+        """
+        return self._targeted
+
+    @targeted.setter
+    def targeted(self, targeted) -> None:
+        self._targeted = targeted
+
 
 class PoisoningAttack(Attack):
     """
     Abstract base class for poisoning attack classes
     """
 
-    def __init__(self, classifier) -> None:
+    def __init__(self, classifier: Optional["CLASSIFIER_TYPE"]) -> None:
+        """
+        :param classifier: A trained classifier (or none if no classifier is needed)
+        """
+        super().__init__(classifier)
+
+    @abc.abstractmethod
+    def poison(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> Tuple[np.ndarray, np.ndarray]:
+        """
+        Generate poisoning examples and return them as an array. This method should be overridden by all concrete
+        poisoning attack implementations.
+
+        :param x: An array with the original inputs to be attacked.
+        :param y: Target labels for `x`. Untargeted attacks set this value to None.
+        :return: A tuple holding the (poisoning examples, poisoning labels).
+        """
+        raise NotImplementedError
+
+
+class PoisoningAttackTransformer(PoisoningAttack):
+    """
+    Abstract base class for poisoning attack classes that return a transformed classifier.
+    These attacks have an additional method, `poison_estimator`, that returns the poisoned classifier.
+    """
+
+    def __init__(self, classifier: Optional["CLASSIFIER_TYPE"], **kwargs) -> None:
         """
         :param classifier: A trained classifier (or none if no classifier is needed)
-        :type classifier: `art.estimators.classification.Classifier` or `None`
         """
         super().__init__(classifier)
 
     @abc.abstractmethod
-    def poison(self, x, y=None, **kwargs):
+    def poison(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> Tuple[np.ndarray, np.ndarray]:
         """
         Generate poisoning examples and return them as an array. This method should be overridden by all concrete
         poisoning attack implementations.
 
         :param x: An array with the original inputs to be attacked.
-        :type x: `np.ndarray`
         :param y: Target labels for `x`. Untargeted attacks set this value to None.
-        :type y: `np.ndarray`
         :return: A tuple holding the (poisoning examples, poisoning labels).
         :rtype: `(np.ndarray, np.ndarray)`
         """
         raise NotImplementedError
 
+    @abc.abstractmethod
+    def poison_estimator(self, x: np.ndarray, y: np.ndarray, **kwargs) -> "CLASSIFIER_TYPE":
+        """
+        Returns a poisoned version of the classifier used to initialize the attack.
+        :param x: Training data.
+        :param y: Training labels.
+        :return: A poisoned classifier.
+        """
+        raise NotImplementedError
+
 
 class PoisoningAttackBlackBox(PoisoningAttack):
     """
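A rough sketch (not part of this commit) of how a concrete transformer-style poisoning attack could satisfy this interface. `DummyTransformerAttack`, its label-flipping stub, and the empty `_estimator_requirements` tuple are assumptions made for illustration; the sketch also assumes the base `Attack` class exposes the wrapped classifier via the `estimator` property.

from typing import Optional, Tuple, TYPE_CHECKING

import numpy as np

from art.attacks.attack import PoisoningAttackTransformer

if TYPE_CHECKING:
    from art.utils import CLASSIFIER_TYPE


class DummyTransformerAttack(PoisoningAttackTransformer):
    """Illustrative subclass; real attacks would restrict the accepted estimators."""

    # An empty tuple means no particular estimator base classes are required.
    _estimator_requirements = ()

    def poison(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> Tuple[np.ndarray, np.ndarray]:
        # Stub: flip the label of the first sample (assumes binary or one-hot labels).
        y_poison = np.copy(y)
        y_poison[0] = 1 - y_poison[0]
        return np.copy(x), y_poison

    def poison_estimator(self, x: np.ndarray, y: np.ndarray, **kwargs) -> "CLASSIFIER_TYPE":
        # Stub: refit the classifier passed at construction time on the poisoned data.
        x_poison, y_poison = self.poison(x, y)
        self.estimator.fit(x_poison, y_poison)
        return self.estimator

A caller would construct such an attack with a trained classifier, call `poison_estimator(x_train, y_train)`, and then evaluate the returned (poisoned) model.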
@@ -221,7 +272,7 @@ class ExtractionAttack(Attack):
     """
 
     @abc.abstractmethod
-    def extract(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> "Classifier":
+    def extract(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> "CLASSIFIER_TYPE":
         """
         Extract models and return them as an ART classifier. This method should be overridden by all concrete extraction
         attack implementations.
@@ -292,7 +343,7 @@ def set_params(self, **kwargs) -> None:
         Take in a dictionary of parameters and apply attack-specific checks before saving them as attributes.
         """
         # Save attack-specific parameters
-        super(AttributeInferenceAttack, self).set_params(**kwargs)
+        super().set_params(**kwargs)
         self._check_params()
 
     def _check_params(self) -> None:
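For completeness, a hedged usage sketch of the `targeted` flag introduced on `EvasionAttack` above, reusing the illustrative `DummyEvasionAttack` from the earlier sketch; `classifier`, `x_test`, and `y_target` stand in for a fitted ART classifier and NumPy test data and are not defined in this commit.

attack = DummyEvasionAttack(estimator=classifier, eps=0.05)

print(attack.targeted)   # False: EvasionAttack.__init__ now initialises the flag
attack.targeted = True   # switch to targeted mode through the new property setter
x_adv = attack.generate(x=x_test, y=y_target)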