@@ -97,6 +97,7 @@ def __init__(
9797 self .max_epochs = max_epochs
9898 self .batch_size = batch_size
9999 self .clip_values = clip_values
100+ self .initial_epoch = 0
100101
101102 if verbose is True :
102103 verbose = 1
@@ -175,7 +176,7 @@ def _weight_grad(classifier: TensorFlowV2Classifier, x: tf.Tensor, target: tf.Te
175176 with tf .GradientTape () as t : # pylint: disable=C0103
176177 t .watch (classifier .model .weights )
177178 output = classifier .model (x , training = False )
178- loss = classifier .model . compiled_loss (target , output )
179+ loss = classifier .loss_object (target , output )
179180 d_w = t .gradient (loss , classifier .model .weights )
180181 d_w = [w for w in d_w if w is not None ]
181182 d_w = tf .concat ([tf .reshape (d , [- 1 ]) for d in d_w ], 0 )
@@ -478,7 +479,11 @@ def __len__(self):
478479 PoisonDataset (x_poison , y_poison ), batch_size = self .batch_size , shuffle = False , num_workers = 1
479480 )
480481
481- epoch_iterator = trange (self .max_epochs ) if self .verbose > 0 else range (self .max_epochs )
482+ epoch_iterator = (
483+ trange (self .initial_epoch , self .max_epochs )
484+ if self .verbose > 0
485+ else range (self .initial_epoch , self .max_epochs )
486+ )
482487 for _ in epoch_iterator :
483488 batch_iterator = tqdm (trainloader ) if isinstance (self .verbose , int ) and self .verbose >= 2 else trainloader
484489 sum_loss = 0
@@ -536,6 +541,7 @@ def _poison__tensorflow(self, x_poison: np.ndarray, y_poison: np.ndarray) -> Tup
536541 [x_poison , y_poison , np .arange (len (y_poison ))],
537542 callbacks = callbacks ,
538543 batch_size = self .batch_size ,
544+ initial_epoch = self .initial_epoch ,
539545 epochs = self .max_epochs ,
540546 verbose = 0 ,
541547 )
0 commit comments