diff --git a/keras/src/backend/tensorflow/trainer.py b/keras/src/backend/tensorflow/trainer.py index c223deff7e0..cd6410999dd 100644 --- a/keras/src/backend/tensorflow/trainer.py +++ b/keras/src/backend/tensorflow/trainer.py @@ -68,7 +68,9 @@ def train_step(self, data): ) self._loss_tracker.update_state( loss_module.unscale_loss_for_distribution(loss), - sample_weight=tf.shape(tree.flatten(x)[0])[0], + sample_weight=tf.shape( + next(i for i in tree.flatten(x) if i is not None) + )[0], ) if self.optimizer is not None: loss = self.optimizer.scale_loss(loss) @@ -96,7 +98,9 @@ def test_step(self, data): ) self._loss_tracker.update_state( loss_module.unscale_loss_for_distribution(loss), - sample_weight=tf.shape(tree.flatten(x)[0])[0], + sample_weight=tf.shape( + next(i for i in tree.flatten(x) if i is not None) + )[0], ) return self.compute_metrics(x, y, y_pred, sample_weight=sample_weight)