@@ -54,8 +54,8 @@ class AdversarialTrainerMadryPGD(Trainer):
5454 def __init__ (
5555 self ,
5656 classifier : "CLASSIFIER_LOSS_GRADIENTS_TYPE" ,
57- nb_epochs : int = 391 ,
58- batch_size : int = 128 ,
57+ nb_epochs : Optional [ int ] = 391 ,
58+ batch_size : Optional [ int ] = 128 ,
5959 eps : Union [int , float ] = 8 ,
6060 eps_step : Union [int , float ] = 2 ,
6161 max_iter : int = 7 ,
@@ -91,18 +91,41 @@ def __init__(
9191 self .trainer = AdversarialTrainer (classifier , self .attack , ratio = 1.0 ) # type: ignore
9292
9393 def fit ( # pylint: disable=W0221
94- self , x : np .ndarray , y : np .ndarray , validation_data : Optional [np .ndarray ] = None , ** kwargs
94+ self ,
95+ x : np .ndarray ,
96+ y : np .ndarray ,
97+ validation_data : Optional [np .ndarray ] = None ,
98+ batch_size : Optional [int ] = None ,
99+ nb_epochs : Optional [int ] = None ,
100+ ** kwargs
95101 ) -> None :
96102 """
97103 Train a model adversarially. See class documentation for more information on the exact procedure.
98104
99105 :param x: Training data.
100106 :param y: Labels for the training data.
101107 :param validation_data: Validation data.
108+ :param batch_size: Size of batches. Overwrites batch_size defined in __init__ if not None.
109+ :param nb_epochs: Number of epochs to use for trainings. Overwrites nb_epochs defined in __init__ if not None.
102110 :param kwargs: Dictionary of framework-specific arguments.
103111 """
112+ batch_size_fit : int
113+ if batch_size is not None :
114+ batch_size_fit = batch_size
115+ elif self .batch_size is not None :
116+ batch_size_fit = self .batch_size
117+ else :
118+ raise ValueError ("Please provide value for `batch_size`." )
119+
120+ if nb_epochs is not None :
121+ nb_epochs_fit : int = nb_epochs
122+ elif self .nb_epochs is not None :
123+ nb_epochs_fit = self .nb_epochs
124+ else :
125+ raise ValueError ("Please provide value for `nb_epochs`." )
126+
104127 self .trainer .fit (
105- x , y , validation_data = validation_data , nb_epochs = self . nb_epochs , batch_size = self . batch_size , ** kwargs
128+ x , y , validation_data = validation_data , nb_epochs = nb_epochs_fit , batch_size = batch_size_fit , ** kwargs
106129 )
107130
108131 def get_classifier (self ) -> "CLASSIFIER_LOSS_GRADIENTS_TYPE" :
0 commit comments