@@ -34,11 +34,15 @@ class Postprocessor(abc.ABC):
3434
3535 params : List [str ] = []
3636
37- def __init__ (self ) -> None :
37+ def __init__ (self , is_fitted : bool = False , apply_fit : bool = True , apply_predict : bool = True ) -> None :
3838 """
3939 Create a postprocessing object.
40+
41+ Optionally, set attributes.
4042 """
41- self ._is_fitted = False
43+ self ._is_fitted = bool (is_fitted )
44+ self ._apply_fit = bool (apply_fit )
45+ self ._apply_predict = bool (apply_predict )
4246
4347 @property
4448 def is_fitted (self ) -> bool :
@@ -50,24 +54,22 @@ def is_fitted(self) -> bool:
5054 return self ._is_fitted
5155
5256 @property
53- @abc .abstractmethod
5457 def apply_fit (self ) -> bool :
5558 """
5659 Property of the defence indicating if it should be applied at training time.
5760
5861 :return: `True` if the defence should be applied when fitting a model, `False` otherwise.
5962 """
60- raise NotImplementedError
63+ return self . _apply_fit
6164
6265 @property
63- @abc .abstractmethod
6466 def apply_predict (self ) -> bool :
6567 """
6668 Property of the defence indicating if it should be applied at test time.
6769
6870 :return: `True` if the defence should be applied at prediction time, `False` otherwise.
6971 """
70- raise NotImplementedError
72+ return self . _apply_predict
7173
7274 @abc .abstractmethod
7375 def __call__ (self , preds : np .ndarray ) -> np .ndarray :
@@ -79,15 +81,14 @@ def __call__(self, preds: np.ndarray) -> np.ndarray:
7981 """
8082 raise NotImplementedError
8183
82- @abc .abstractmethod
8384 def fit (self , preds : np .ndarray , ** kwargs ) -> None :
8485 """
8586 Fit the parameters of the postprocessor if it has any.
8687
8788 :param preds: Training set to fit the postprocessor.
8889 :param kwargs: Other parameters.
8990 """
90- raise NotImplementedError
91+ pass
9192
9293 def set_params (self , ** kwargs ) -> None :
9394 """
0 commit comments