8 changes: 6 additions & 2 deletions keras/src/backend/tensorflow/trainer.py
@@ -68,7 +68,9 @@ def train_step(self, data):
             )
             self._loss_tracker.update_state(
                 loss_module.unscale_loss_for_distribution(loss),
-                sample_weight=tf.shape(tree.flatten(x)[0])[0],
+                sample_weight=tf.shape(
+                    next(i for i in tree.flatten(x) if i is not None)
+                )[0],
Comment on lines +71 to +73
Contributor

medium

This change correctly handles cases where the first input is None. However, it introduces a risk of a StopIteration error if all inputs in x are None. This can be difficult to debug, especially inside a tf.function.

A more robust approach would be to handle this edge case explicitly, for example by raising a ValueError with a clear message.

Also, this logic is duplicated in test_step. Consider extracting it into a private helper method to improve maintainability and ensure consistency.
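A minimal sketch of what the suggested extraction could look like, shown as a module-level helper for brevity (the reviewer's private method on the trainer would have the same body). The name `_first_non_none_batch_size` and the error message are illustrative assumptions, not part of this PR; only `tree.flatten` and `tf.shape` come from the diff itself.

```python
import tensorflow as tf

from keras.src import tree  # same `tree` utility already used in trainer.py


def _first_non_none_batch_size(x):
    """Return the batch size of the first non-None flattened input.

    Raising a descriptive ValueError here avoids the bare StopIteration
    that `next()` would surface, which is hard to debug inside a
    tf.function.
    """
    for tensor in tree.flatten(x):
        if tensor is not None:
            return tf.shape(tensor)[0]
    raise ValueError(
        "Cannot infer a batch size for the loss tracker: all flattened "
        "inputs in `x` are None."
    )
```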

Collaborator

+1

             )
             if self.optimizer is not None:
                 loss = self.optimizer.scale_loss(loss)
@@ -96,7 +98,9 @@ def test_step(self, data):
         )
         self._loss_tracker.update_state(
             loss_module.unscale_loss_for_distribution(loss),
-            sample_weight=tf.shape(tree.flatten(x)[0])[0],
+            sample_weight=tf.shape(
+                next(i for i in tree.flatten(x) if i is not None)
+            )[0],
Comment on lines +101 to +103
Contributor

medium

As in train_step, this change is vulnerable to a StopIteration error if all inputs are None. Handling this edge case explicitly would make the code more robust and prevent a potential runtime crash.
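If the hypothetical helper sketched under the train_step comment were adopted, both call sites could share the same form; this is an illustrative refactor, not code from this PR:

```python
# Identical call site in train_step and test_step (hypothetical refactor):
self._loss_tracker.update_state(
    loss_module.unscale_loss_for_distribution(loss),
    sample_weight=_first_non_none_batch_size(x),
)
```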

Collaborator

+1

         )
         return self.compute_metrics(x, y, y_pred, sample_weight=sample_weight)
