@@ -85,12 +85,16 @@ def replacement_function(self, *args, **kwargs):
8585 setattr (cls , item , new_function )
8686
8787
88- class Attack (abc .ABC , metaclass = InputFilter ):
88+ class Attack (abc .ABC ):
8989 """
9090 Abstract base class for all attack abstract base classes.
9191 """
9292
9393 attack_params : List [str ] = list ()
94+ # The _estimator_requirements define the requirements an estimator must satisfy to be used as a target for an
95+ # attack. They should be a tuple of requirements, where each requirement is either a class the estimator must
96+ # inherit from, or a tuple of classes which define a union, i.e. the estimator must inherit from at least one class
97+ # in the requirement tuple.
9498 _estimator_requirements : Optional [Union [Tuple [Any , ...], Tuple [()]]] = None
9599
96100 def __init__ (
@@ -111,7 +115,7 @@ def __init__(
111115 if self .estimator_requirements is None :
112116 raise ValueError ("Estimator requirements have not been defined in `_estimator_requirements`." )
113117
114- if not all ( t in type ( estimator ). __mro__ for t in self .estimator_requirements ):
118+ if not self .is_estimator_valid ( estimator ):
115119 raise EstimatorError (self .__class__ , self .estimator_requirements , estimator )
116120
117121 self ._estimator = estimator
@@ -155,6 +159,24 @@ def _check_params(self) -> None:
155159 if not isinstance (self .tensor_board , (bool , str )):
156160 raise ValueError ("The argument `tensor_board` has to be either of type bool or str." )
157161
162+ def is_estimator_valid (self , estimator ) -> bool :
163+ """
164+ Checks if the given estimator satisfies the requirements for this attack.
165+
166+ :param estimator: The estimator to check.
167+ :return: True if the estimator is valid for the attack.
168+ """
169+
170+ for req in self .estimator_requirements :
171+ # A requirement is either a class which the estimator must inherit from, or a tuple of classes and the
172+ # estimator is required to inherit from at least one of the classes
173+ if isinstance (req , tuple ):
174+ if all (p not in type (estimator ).__mro__ for p in req ):
175+ return False
176+ elif req not in type (estimator ).__mro__ :
177+ return False
178+ return True
179+
158180
159181class EvasionAttack (Attack ):
160182 """
@@ -175,7 +197,7 @@ def generate( # lgtm [py/inheritance/incorrect-overridden-signature]
175197
176198 :param x: An array with the original inputs to be attacked.
177199 :param y: Correct labels or target labels for `x`, depending if the attack is targeted
178- or not. This parameter is only used by some of the attacks.
200+ or not. This parameter is only used by some of the attacks.
179201 :return: An array holding the adversarial examples.
180202 """
181203 raise NotImplementedError
@@ -373,7 +395,7 @@ class MembershipInferenceAttack(InferenceAttack):
373395 Abstract base class for membership inference attack classes.
374396 """
375397
376- def __init__ (self , estimator : Union [ "CLASSIFIER_TYPE" ] ):
398+ def __init__ (self , estimator ):
377399 """
378400 :param estimator: A trained estimator targeted for inference attack.
379401 :type estimator: :class:`.art.estimators.estimator.BaseEstimator`
0 commit comments