@@ -54,8 +54,8 @@ class AdversarialTrainerMadryPGD(Trainer):
5454 def __init__ (
5555 self ,
5656 classifier : "CLASSIFIER_LOSS_GRADIENTS_TYPE" ,
57- nb_epochs : int = 391 ,
58- batch_size : int = 128 ,
57+ nb_epochs : Optional [ int ] = 391 ,
58+ batch_size : Optional [ int ] = 128 ,
5959 eps : Union [int , float ] = 8 ,
6060 eps_step : Union [int , float ] = 2 ,
6161 max_iter : int = 7 ,
@@ -91,18 +91,36 @@ def __init__(
9191 self .trainer = AdversarialTrainer (classifier , self .attack , ratio = 1.0 ) # type: ignore
9292
9393 def fit ( # pylint: disable=W0221
94- self , x : np .ndarray , y : np .ndarray , validation_data : Optional [np .ndarray ] = None , ** kwargs
94+ self ,
95+ x : np .ndarray ,
96+ y : np .ndarray ,
97+ validation_data : Optional [np .ndarray ] = None ,
98+ batch_size : Optional [int ] = None ,
99+ nb_epochs : Optional [int ] = None ,
100+ ** kwargs
95101 ) -> None :
96102 """
97103 Train a model adversarially. See class documentation for more information on the exact procedure.
98104
99105 :param x: Training data.
100106 :param y: Labels for the training data.
101107 :param validation_data: Validation data.
108+ :param batch_size: Size of batches. Overwrites batch_size defined in __init__ if not None.
109+ :param nb_epochs: Number of epochs to use for trainings. Overwrites nb_epochs defined in __init__ if not None.
102110 :param kwargs: Dictionary of framework-specific arguments.
103111 """
112+ if batch_size is not None :
113+ batch_size_fit = batch_size
114+ else :
115+ batch_size_fit = self .batch_size
116+
117+ if nb_epochs is not None :
118+ nb_epochs_fit = nb_epochs
119+ else :
120+ nb_epochs_fit = self .nb_epochs
121+
104122 self .trainer .fit (
105- x , y , validation_data = validation_data , nb_epochs = self . nb_epochs , batch_size = self . batch_size , ** kwargs
123+ x , y , validation_data = validation_data , nb_epochs = nb_epochs_fit , batch_size = batch_size_fit , ** kwargs
106124 )
107125
108126 def get_classifier (self ) -> "CLASSIFIER_LOSS_GRADIENTS_TYPE" :
0 commit comments