@@ -527,6 +527,7 @@ class CarliniLInfMethod(EvasionAttack):
527527 "initial_const" ,
528528 "largest_const" ,
529529 "const_factor" ,
530+ "batch_size" ,
530531 "verbose" ,
531532 ]
532533 _estimator_requirements = (BaseEstimator , ClassGradientsMixin )
@@ -542,6 +543,7 @@ def __init__(
542543 initial_const : float = 1e-5 ,
543544 largest_const : float = 20.0 ,
544545 const_factor : float = 2.0 ,
546+ batch_size : int = 1 ,
545547 verbose : bool = True ,
546548 ) -> None :
547549 """
@@ -559,6 +561,7 @@ def __init__(
559561 :param initial_const: The initial value of constant `c`.
560562 :param largest_const: The largest value of constant `c`.
561563 :param const_factor: The rate of increasing constant `c` with `const_factor > 1`, where smaller more accurate.
564+ :param batch_size: Size of the batch on which adversarial samples are generated.
562565 :param verbose: Show progress bars.
563566 """
564567 super ().__init__ (estimator = classifier )
@@ -571,6 +574,7 @@ def __init__(
571574 self .initial_const = initial_const
572575 self .largest_const = largest_const
573576 self .const_factor = const_factor
577+ self .batch_size = batch_size
574578 self .verbose = verbose
575579 self ._check_params ()
576580
@@ -591,7 +595,7 @@ def _loss(
591595 :param tau: Current limit `tau`.
592596 :return: A tuple of current predictions, total loss, logits loss and regularisation loss.
593597 """
594- z_predicted = self .estimator .predict (np .array (x_adv , dtype = ART_NUMPY_DTYPE ), batch_size = 1 )
598+ z_predicted = self .estimator .predict (np .array (x_adv , dtype = ART_NUMPY_DTYPE ), batch_size = self . batch_size )
595599 z_target = np .sum (z_predicted * target , axis = 1 )
596600 z_other = np .max (
597601 z_predicted * (1 - target ) + (np .min (z_predicted , axis = 1 ) - 1 )[:, np .newaxis ] * target ,
@@ -753,7 +757,7 @@ def generate(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> n
753757
754758 # No labels provided, use model prediction as correct class
755759 if y is None :
756- y = get_labels_np_array (self .estimator .predict (x , batch_size = 1 ))
760+ y = get_labels_np_array (self .estimator .predict (x , batch_size = self . batch_size ))
757761
758762 if self .estimator .nb_classes == 2 and y .shape [1 ] == 1 :
759763 raise ValueError ( # pragma: no cover
@@ -830,6 +834,9 @@ def _check_params(self) -> None:
830834 if not isinstance (self .const_factor , (int , float )) or self .const_factor < 0 :
831835 raise ValueError ("The constant factor value must be a float and greater than 1." )
832836
837+ if not isinstance (self .batch_size , int ) or self .batch_size < 1 :
838+ raise ValueError ("The batch size must be an integer greater than zero." )
839+
833840
834841class CarliniL0Method (CarliniL2Method ):
835842 """
0 commit comments