2525from tqdm import trange
2626
2727from art .config import ART_NUMPY_DTYPE
28- from art .defences .postprocessor .postprocessor import Postprocessor
29- from art .defences .preprocessor .preprocessor import Preprocessor
3028from art .utils import Deprecated , deprecated , deprecated_keyword_arg
3129
3230if TYPE_CHECKING :
3331 # pylint: disable=R0401
3432 from art .utils import CLIP_VALUES_TYPE , PREPROCESSING_TYPE
3533 from art .data_generators import DataGenerator
3634 from art .metrics .verification_decisions_trees import Tree
35+ from art .defences .postprocessor .postprocessor import Postprocessor
36+ from art .defences .preprocessor .preprocessor import Preprocessor
3737
3838
3939class BaseEstimator (ABC ):
@@ -54,8 +54,8 @@ def __init__(
5454 self ,
5555 model = None ,
5656 clip_values : Optional ["CLIP_VALUES_TYPE" ] = None ,
57- preprocessing_defences : Union [Preprocessor , List [Preprocessor ], None ] = None ,
58- postprocessing_defences : Union [Postprocessor , List [Postprocessor ], None ] = None ,
57+ preprocessing_defences : Union [" Preprocessor" , List [" Preprocessor" ], None ] = None ,
58+ postprocessing_defences : Union [" Postprocessor" , List [" Postprocessor" ], None ] = None ,
5959 preprocessing : "PREPROCESSING_TYPE" = (0 , 1 ),
6060 ):
6161 """
@@ -75,14 +75,14 @@ def __init__(
7575 self ._model = model
7676 self ._clip_values = clip_values
7777
78- self .preprocessing_defences : Optional [List [Preprocessor ]]
79- if isinstance (preprocessing_defences , Preprocessor ):
78+ self .preprocessing_defences : Optional [List [" Preprocessor" ]]
79+ if isinstance (preprocessing_defences , " Preprocessor" ):
8080 self .preprocessing_defences = [preprocessing_defences ]
8181 else :
8282 self .preprocessing_defences = preprocessing_defences
8383
84- self .postprocessing_defences : Optional [List [Postprocessor ]]
85- if isinstance (postprocessing_defences , Postprocessor ):
84+ self .postprocessing_defences : Optional [List [" Postprocessor" ]]
85+ if isinstance (postprocessing_defences , " Postprocessor" ):
8686 self .postprocessing_defences = [postprocessing_defences ]
8787 else :
8888 self .postprocessing_defences = postprocessing_defences
@@ -133,7 +133,7 @@ def _check_params(self) -> None:
133133
134134 if isinstance (self .preprocessing_defences , list ):
135135 for preproc_defence in self .preprocessing_defences :
136- if not isinstance (preproc_defence , Preprocessor ):
136+ if not isinstance (preproc_defence , " Preprocessor" ):
137137 raise ValueError (
138138 "All preprocessing defences have to be instance of "
139139 "art.defences.preprocessor.preprocessor.Preprocessor."
@@ -147,7 +147,7 @@ def _check_params(self) -> None:
147147 )
148148 if isinstance (self .postprocessing_defences , list ):
149149 for postproc_defence in self .postprocessing_defences :
150- if not isinstance (postproc_defence , Postprocessor ):
150+ if not isinstance (postproc_defence , " Postprocessor" ):
151151 raise ValueError (
152152 "All postprocessing defences have to be instance of "
153153 "art.defences.postprocessor.postprocessor.Postprocessor."
0 commit comments