@@ -800,6 +800,7 @@ class TensorFlowV2Classifier(ClassGradientsMixin, ClassifierMixin, TensorFlowV2E
         + [
             "input_shape",
             "loss_object",
+            "optimizer",
             "train_step",
         ]
     )
@@ -810,6 +811,7 @@ def __init__(
         nb_classes: int,
         input_shape: Tuple[int, ...],
         loss_object: Optional["tf.keras.losses.Loss"] = None,
+        optimizer: Optional["tf.keras.optimizers.Optimizer"] = None,
         train_step: Optional[Callable] = None,
         channels_first: bool = False,
         clip_values: Optional["CLIP_VALUES_TYPE"] = None,
@@ -824,10 +826,12 @@ def __init__(
         :param nb_classes: the number of classes in the classification task.
         :param input_shape: shape of one input for the classifier, e.g. for MNIST input_shape=(28, 28, 1).
         :param loss_object: The loss function for which to compute gradients. This parameter is applied for training
-            the model and computing gradients of the loss w.r.t. the input.
-        :type loss_object: `tf.keras.losses`
+               the model and computing gradients of the loss w.r.t. the input.
+        :param optimizer: The optimizer used to train the classifier.
         :param train_step: A function that applies a gradient update to the trainable variables with signature
-               train_step(model, images, labels).
+               `train_step(model, images, labels)`. This will override the default training loop that uses the
+               provided `loss_object` and `optimizer` parameters. It is recommended to use the `@tf.function`
+               decorator, if possible, for efficient training.
         :param channels_first: Set channels first or last.
         :param clip_values: Tuple of the form `(min, max)` of floats or `np.ndarray` representing the minimum and
                maximum values allowed for features. If floats are provided, these will be used as the range of all
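
For reference, a custom `train_step` matching this documented signature could look like the sketch below; the loss and optimizer objects here are illustrative stand-ins, not part of this diff:

```python
import tensorflow as tf

# Illustrative loss/optimizer pair; any tf.keras loss and optimizer
# would be wired up the same way.
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
opt = tf.keras.optimizers.Adam(learning_rate=1e-3)

@tf.function  # recommended by the docstring for efficient training
def train_step(model, images, labels):
    # One gradient update on a single batch, using the documented signature.
    with tf.GradientTape() as tape:
        predictions = model(images, training=True)
        loss = loss_fn(labels, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    opt.apply_gradients(zip(gradients, model.trainable_variables))
```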
@@ -853,6 +857,7 @@ def __init__(
         self.nb_classes = nb_classes
         self._input_shape = input_shape
         self._loss_object = loss_object
+        self._optimizer = optimizer
         self._train_step = train_step

         # Check if the loss function requires as input index labels instead of one-hot-encoded labels
@@ -879,6 +884,15 @@ def loss_object(self) -> "tf.keras.losses.Loss":
         """
         return self._loss_object  # type: ignore

+    @property
+    def optimizer(self) -> "tf.keras.optimizers.Optimizer":
+        """
+        Return the optimizer.
+
+        :return: The optimizer.
+        """
+        return self._optimizer  # type: ignore
+
     @property
     def train_step(self) -> Callable:
         """
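
Exposing the optimizer as a read-only property also lets callers inspect or tune it between fits. A minimal sketch, assuming a standard `tf.keras` optimizer whose `learning_rate` supports attribute assignment:

```python
# `classifier` is assumed to be a TensorFlowV2Classifier built with an optimizer.
print(float(classifier.optimizer.learning_rate))
classifier.optimizer.learning_rate = 1e-4  # e.g. decay the rate before a later fit()
```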
@@ -949,9 +963,27 @@ def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: in
         import tensorflow as tf

         if self._train_step is None:  # pragma: no cover
-            raise TypeError(
-                "The training function `train_step` is required for fitting a model but it has not been " "defined."
-            )
+            if self._loss_object is None:  # pragma: no cover
+                raise TypeError(
+                    "A loss function `loss_object` or training function `train_step` is required for fitting the "
+                    "model, but it has not been defined."
+                )
+            if self._optimizer is None:  # pragma: no cover
+                raise ValueError(
+                    "An optimizer `optimizer` or training function `train_step` is required for fitting the "
+                    "model, but it has not been defined."
+                )
+
+            @tf.function
+            def train_step(model, images, labels):
+                with tf.GradientTape() as tape:
+                    predictions = model(images, training=True)
+                    loss = self.loss_object(labels, predictions)
+                gradients = tape.gradient(loss, model.trainable_variables)
+                self.optimizer.apply_gradients(zip(gradients, model.trainable_variables))
+
+        else:
+            train_step = self._train_step

         y = check_and_transform_label_format(y, nb_classes=self.nb_classes)

@@ -966,7 +998,7 @@ def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: in

         for _ in range(nb_epochs):
             for images, labels in train_ds:
-                self._train_step(self.model, images, labels)
+                train_step(self.model, images, labels)

     def fit_generator(self, generator: "DataGenerator", nb_epochs: int = 20, **kwargs) -> None:
         """
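
With this fallback in place, a classifier constructed with only `loss_object` and `optimizer` can now be fitted without supplying `train_step`. A minimal sketch; the model and data below are illustrative:

```python
import numpy as np
import tensorflow as tf
from art.estimators.classification import TensorFlowV2Classifier

# Toy MNIST-shaped model, for illustration only.
model = tf.keras.Sequential(
    [tf.keras.layers.Flatten(input_shape=(28, 28, 1)), tf.keras.layers.Dense(10)]
)

classifier = TensorFlowV2Classifier(
    model=model,
    nb_classes=10,
    input_shape=(28, 28, 1),
    loss_object=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.01),
    # no train_step: the default @tf.function loop defined in fit() is used
)

x = np.random.rand(128, 28, 28, 1).astype(np.float32)
y = np.eye(10)[np.random.randint(0, 10, size=128)]
classifier.fit(x, y, batch_size=32, nb_epochs=1)
```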
@@ -982,9 +1014,27 @@ def fit_generator(self, generator: "DataGenerator", nb_epochs: int = 20, **kwarg
         import tensorflow as tf
         from art.data_generators import TensorFlowV2DataGenerator

         if self._train_step is None:  # pragma: no cover
-            raise TypeError(
-                "The training function `train_step` is required for fitting a model but it has not been " "defined."
-            )
+            if self._loss_object is None:  # pragma: no cover
+                raise TypeError(
+                    "A loss function `loss_object` or training function `train_step` is required for fitting the "
+                    "model, but it has not been defined."
+                )
+            if self._optimizer is None:  # pragma: no cover
+                raise ValueError(
+                    "An optimizer `optimizer` or training function `train_step` is required for fitting the "
+                    "model, but it has not been defined."
+                )
+
+            @tf.function
+            def train_step(model, images, labels):
+                with tf.GradientTape() as tape:
+                    predictions = model(images, training=True)
+                    loss = self.loss_object(labels, predictions)
+                gradients = tape.gradient(loss, model.trainable_variables)
+                self.optimizer.apply_gradients(zip(gradients, model.trainable_variables))
+
+        else:
+            train_step = self._train_step

         # Train directly in TensorFlow
         from art.preprocessing.standardisation_mean_std.tensorflow import StandardisationMeanStdTensorFlow
@@ -1004,7 +1054,7 @@ def fit_generator(self, generator: "DataGenerator", nb_epochs: int = 20, **kwarg
                 for i_batch, o_batch in generator.iterator:
                     if self._reduce_labels:
                         o_batch = tf.math.argmax(o_batch, axis=1)
-                    self._train_step(self._model, i_batch, o_batch)
+                    train_step(self._model, i_batch, o_batch)
         else:
             # Fit a generic data generator through the API
             super().fit_generator(generator, nb_epochs=nb_epochs)
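
The same fallback applies in `fit_generator` when a `TensorFlowV2DataGenerator` is used, so the direct TensorFlow path no longer requires a custom `train_step` either. Roughly, reusing `classifier`, `x`, and `y` from the sketch above:

```python
import tensorflow as tf
from art.data_generators import TensorFlowV2DataGenerator

# Wrap a tf.data pipeline in ART's generator; sizes are illustrative.
dataset = tf.data.Dataset.from_tensor_slices((x, y)).shuffle(128).batch(32)
generator = TensorFlowV2DataGenerator(iterator=dataset, size=128, batch_size=32)
classifier.fit_generator(generator, nb_epochs=1)
```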
@@ -1263,6 +1313,7 @@ def clone_for_refitting(
         clone._train_step = self._train_step  # pylint: disable=W0212
         clone._reduce_labels = self._reduce_labels  # pylint: disable=W0212
         clone._loss_object = self._loss_object  # pylint: disable=W0212
+        clone._optimizer = self._optimizer  # pylint: disable=W0212
         return clone

     def reset(self) -> None:
@@ -1401,8 +1452,8 @@ def save(self, filename: str, path: Optional[str] = None) -> None:
     def __repr__(self):
         repr_ = (
             f"{self.__module__ + '.' + self.__class__.__name__}(model={self._model}, nb_classes={self.nb_classes}, "
-            f"input_shape={self._input_shape}, loss_object={self._loss_object}, train_step={self._train_step}, "
-            f"channels_first={self.channels_first}, clip_values={self.clip_values!r}, "
+            f"input_shape={self._input_shape}, loss_object={self._loss_object}, optimizer={self.optimizer}, "
+            f"train_step={self._train_step}, channels_first={self.channels_first}, clip_values={self.clip_values!r}, "
             f"preprocessing_defences={self.preprocessing_defences}, "
             f"postprocessing_defences={self.postprocessing_defences}, preprocessing={self.preprocessing})"
         )