19 changes: 16 additions & 3 deletions keras/src/backend/jax/trainer.py
@@ -638,7 +638,13 @@ def evaluate(
 
     @traceback_utils.filter_traceback
     def predict(
-        self, x, batch_size=None, verbose="auto", steps=None, callbacks=None
+        self,
+        x,
+        batch_size=None,
+        verbose="auto",
+        steps=None,
+        callbacks=None,
+        accumulate=True,
     ):
         # Create an iterator that yields batches of input data.
         epoch_iterator = JAXEpochIterator(
@@ -718,7 +724,8 @@ def append_to_outputs(batch_outputs, outputs):
                     # during predict(), but it's allowed.
                     "non_trainable_variables": non_trainable_variables,
                 }
-                outputs = append_to_outputs(batch_outputs, outputs)
+                if accumulate:
+                    outputs = append_to_outputs(batch_outputs, outputs)
 
                 # Dispatch callbacks. This takes care of async dispatch.
                 callbacks.on_predict_batch_end(
@@ -731,7 +738,13 @@ def append_to_outputs(batch_outputs, outputs):
         self.jax_state_sync()
         callbacks.on_predict_end()
         self._jax_state = None
-        return tree.map_structure_up_to(batch_outputs, np.concatenate, outputs)
+        if accumulate:
+            if outputs is None:
+                return None
+            return tree.map_structure_up_to(
+                batch_outputs, np.concatenate, outputs
+            )
+        return outputs
 
     def train_on_batch(
         self,
19 changes: 16 additions & 3 deletions keras/src/backend/numpy/trainer.py
@@ -170,7 +170,13 @@ def fit(
 
     @traceback_utils.filter_traceback
     def predict(
-        self, x, batch_size=None, verbose="auto", steps=None, callbacks=None
+        self,
+        x,
+        batch_size=None,
+        verbose="auto",
+        steps=None,
+        callbacks=None,
+        accumulate=True,
     ):
         # Create an iterator that yields batches of input data.
         epoch_iterator = EpochIterator(
@@ -214,12 +220,19 @@ def append_to_outputs(batch_outputs, outputs):
         for begin_step, end_step, data in epoch_iterator:
             callbacks.on_predict_batch_begin(begin_step)
             batch_outputs = self.predict_function(data)
-            outputs = append_to_outputs(batch_outputs, outputs)
+            if accumulate:
+                outputs = append_to_outputs(batch_outputs, outputs)
             callbacks.on_predict_batch_end(end_step, {"outputs": batch_outputs})
             if self.stop_predicting:
                 break
         callbacks.on_predict_end()
-        return tree.map_structure_up_to(batch_outputs, np.concatenate, outputs)
+        if accumulate:
+            if outputs is None:
+                return None
+            return tree.map_structure_up_to(
+                batch_outputs, np.concatenate, outputs
+            )
+        return outputs
 
     @traceback_utils.filter_traceback
     def evaluate(
23 changes: 18 additions & 5 deletions keras/src/backend/openvino/trainer.py
@@ -171,7 +171,13 @@ def fit(
 
     @traceback_utils.filter_traceback
     def predict(
-        self, x, batch_size=None, verbose="auto", steps=None, callbacks=None
+        self,
+        x,
+        batch_size=None,
+        verbose="auto",
+        steps=None,
+        callbacks=None,
+        accumulate=True,
     ):
         # Create an iterator that yields batches of input data.
         epoch_iterator = EpochIterator(
@@ -213,15 +219,22 @@ def append_to_outputs(batch_outputs, outputs):
         self.stop_predicting = False
         callbacks.on_predict_begin()
         outputs = None
-        for begin_step, end_step, data in epoch_iterator.enumerate_epoch():
+        for begin_step, end_step, iterator in epoch_iterator:
            callbacks.on_predict_batch_begin(begin_step)
-            batch_outputs = self.predict_function(data)
-            outputs = append_to_outputs(batch_outputs, outputs)
+            batch_outputs = self.predict_function(iterator)
+            if accumulate:
+                outputs = append_to_outputs(batch_outputs, outputs)
             callbacks.on_predict_batch_end(end_step, {"outputs": batch_outputs})
             if self.stop_predicting:
                 break
         callbacks.on_predict_end()
-        return tree.map_structure_up_to(batch_outputs, np.concatenate, outputs)
+        if accumulate:
+            if outputs is None:
+                return None
+            return tree.map_structure_up_to(
+                batch_outputs, np.concatenate, outputs
+            )
+        return outputs
 
     @traceback_utils.filter_traceback
     def evaluate(
23 changes: 17 additions & 6 deletions keras/src/backend/tensorflow/trainer.py
@@ -521,7 +521,13 @@ def evaluate(
 
     @traceback_utils.filter_traceback
     def predict(
-        self, x, batch_size=None, verbose="auto", steps=None, callbacks=None
+        self,
+        x,
+        batch_size=None,
+        verbose="auto",
+        steps=None,
+        callbacks=None,
+        accumulate=True,
     ):
         # Create an iterator that yields batches of input data.
         epoch_iterator = TFEpochIterator(
@@ -586,17 +592,22 @@ def get_data(iterator):
                 callbacks.on_predict_batch_begin(begin_step)
                 data = get_data(iterator)
                 batch_outputs = self.predict_function(data)
-                outputs = append_to_outputs(batch_outputs, outputs)
+                if accumulate:
+                    outputs = append_to_outputs(batch_outputs, outputs)
                 callbacks.on_predict_batch_end(
                     end_step, {"outputs": batch_outputs}
                 )
                 if self.stop_predicting:
                     break
         callbacks.on_predict_end()
-        outputs = tree.map_structure_up_to(
-            batch_outputs, potentially_ragged_concat, outputs
-        )
-        return tree.map_structure(convert_to_np_if_not_ragged, outputs)
+        if accumulate:
+            if outputs is None:
+                return None
+            outputs = tree.map_structure_up_to(
+                batch_outputs, potentially_ragged_concat, outputs
+            )
+            return tree.map_structure(convert_to_np_if_not_ragged, outputs)
+        return outputs
 
     def train_on_batch(
         self,
21 changes: 17 additions & 4 deletions keras/src/backend/torch/trainer.py
@@ -395,7 +395,13 @@ def evaluate(
 
     @traceback_utils.filter_traceback
     def predict(
-        self, x, batch_size=None, verbose="auto", steps=None, callbacks=None
+        self,
+        x,
+        batch_size=None,
+        verbose="auto",
+        steps=None,
+        callbacks=None,
+        accumulate=True,
     ):
         # Create an iterator that yields batches of input data.
         epoch_iterator = TorchEpochIterator(
@@ -442,13 +448,20 @@ def append_to_outputs(batch_outputs, outputs):
         for begin_step, end_step, data in epoch_iterator:
             callbacks.on_predict_batch_begin(begin_step)
             batch_outputs = self.predict_function(data)
-            outputs = append_to_outputs(batch_outputs, outputs)
+            if accumulate:
+                outputs = append_to_outputs(batch_outputs, outputs)
             callbacks.on_predict_batch_end(end_step, {"outputs": batch_outputs})
             if self.stop_predicting:
                 break
         callbacks.on_predict_end()
-        outputs = tree.map_structure(backend.convert_to_numpy, outputs)
-        return tree.map_structure_up_to(batch_outputs, np.concatenate, outputs)
+        if accumulate:
+            if outputs is None:
+                return None
+            outputs = tree.map_structure(backend.convert_to_numpy, outputs)
+            return tree.map_structure_up_to(
+                batch_outputs, np.concatenate, outputs
+            )
+        return outputs
 
     def train_on_batch(
         self,
15 changes: 13 additions & 2 deletions keras/src/trainers/trainer.py
@@ -807,7 +807,13 @@ def evaluate(
         raise NotImplementedError
 
     def predict(
-        self, x, batch_size=None, verbose="auto", steps=None, callbacks=None
+        self,
+        x,
+        batch_size=None,
+        verbose="auto",
+        steps=None,
+        callbacks=None,
+        accumulate=True,
     ):
         """Generates output predictions for the input samples.
 
@@ -858,9 +864,14 @@ def predict(
                 repeating dataset, it will run indefinitely.
             callbacks: List of `keras.callbacks.Callback` instances.
                 List of callbacks to apply during prediction.
+            accumulate: Boolean. Whether to accumulate predictions in memory.
+                If `False`, predictions are not returned and must be handled
+                via callbacks to avoid memory issues with large datasets.
+                Defaults to `True`.
 
         Returns:
-            NumPy array(s) of predictions.
+            NumPy array(s) of predictions if `accumulate=True`,
+            otherwise `None`.
         """
         raise NotImplementedError
 
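A minimal usage sketch of the `accumulate=False` path described in the docstring above, assuming a Keras build that includes this new argument. The toy model, input array, batch size, and `StreamingPredictionCallback` name are illustrative assumptions rather than code from this change; the only pieces taken from the diff are `predict(..., accumulate=False)` returning `None` and the per-batch `outputs` entry passed to `on_predict_batch_end`.

import numpy as np
import keras


class StreamingPredictionCallback(keras.callbacks.Callback):
    """Collect each batch's predictions as they are produced."""

    def __init__(self):
        super().__init__()
        self.batches = []

    def on_predict_batch_end(self, batch, logs=None):
        # With accumulate=False, this hook is the only place the batch
        # outputs are surfaced; predict() itself returns None.
        if logs and "outputs" in logs:
            self.batches.append(logs["outputs"])


# Toy model and data, purely for illustration.
model = keras.Sequential(
    [keras.Input(shape=(4,)), keras.layers.Dense(3)]
)
x = np.ones((10, 4), dtype="float32")

cb = StreamingPredictionCallback()
# accumulate=False is the argument introduced by this diff.
result = model.predict(x, batch_size=2, accumulate=False, callbacks=[cb])
assert result is None  # nothing is accumulated or returned
reassembled = np.concatenate(cb.batches, axis=0)  # 5 batches of 2 -> (10, 3)

In practice the callback could write each batch to disk or push it to a queue instead of keeping a Python list, which is the large-dataset scenario the new argument is aimed at.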
87 changes: 87 additions & 0 deletions keras/src/trainers/trainer_test.py
@@ -2207,6 +2207,93 @@ def test_predict_dropout(self):
         out3 = model.predict_on_batch(np.ones((2, 20)))
         self.assertGreater(5, np.sum(np.abs(out2 - out3)))
 
+    def test_predict_accumulate_parameter(self):
+        # Test that `predict` with accumulate=True/False works correctly
+        model = ExampleModel(units=3)
+        x = np.ones((10, 4))
+
+        # Test accumulate=True (default behavior)
+        outputs_accumulated = model.predict(x, batch_size=2, accumulate=True)
+        self.assertIsInstance(outputs_accumulated, np.ndarray)
+        self.assertEqual(outputs_accumulated.shape, (10, 3))
+        self.assertAllClose(outputs_accumulated, 4 * np.ones((10, 3)))
+
+        # Test accumulate=False with callback to capture outputs
+        class OutputCaptureCallback(Callback):
+            def __init__(self):
+                super().__init__()
+                self.outputs = []
+
+            def on_predict_batch_end(self, batch, logs=None):
+                if logs and "outputs" in logs:
+                    self.outputs.append(logs["outputs"])
+
+        callback = OutputCaptureCallback()
+        outputs_none = model.predict(
+            x, batch_size=2, accumulate=False, callbacks=[callback]
+        )
+
+        # Verify accumulate=False returns None
+        self.assertIsNone(outputs_none)
+
+        # Verify callback captured the correct number of batches
+        self.assertEqual(
+            len(callback.outputs), 5
+        )  # 10 samples / 2 batch_size = 5 batches
+
+        # Verify callback outputs match accumulated outputs when concatenated
+        concatenated_outputs = np.concatenate(callback.outputs, axis=0)
+        self.assertAllClose(outputs_accumulated, concatenated_outputs)
+
+    def test_predict_accumulate_parameter_multi_output(self):
+        # Test accumulate parameter with multi-output model
+        inputs = layers.Input((4,))
+        output1 = layers.Dense(3, name="out1")(inputs)
+        output2 = layers.Dense(2, name="out2")(inputs)
+        model = models.Model(inputs=inputs, outputs=[output1, output2])
+
+        x = np.ones((8, 4))
+
+        # Test accumulate=True (default behavior)
+        outputs_accumulated = model.predict(x, batch_size=2, accumulate=True)
+        self.assertIsInstance(outputs_accumulated, list)
+        self.assertEqual(len(outputs_accumulated), 2)
+        self.assertEqual(outputs_accumulated[0].shape, (8, 3))
+        self.assertEqual(outputs_accumulated[1].shape, (8, 2))
+
+        # Test accumulate=False with callback
+        class OutputCaptureCallback(Callback):
+            def __init__(self):
+                super().__init__()
+                self.outputs = []
+
+            def on_predict_batch_end(self, batch, logs=None):
+                if logs and "outputs" in logs:
+                    self.outputs.append(logs["outputs"])
+
+        callback = OutputCaptureCallback()
+        outputs_none = model.predict(
+            x, batch_size=2, accumulate=False, callbacks=[callback]
+        )
+
+        # Verify accumulate=False returns None
+        self.assertIsNone(outputs_none)
+
+        # Verify callback captured the correct outputs
+        self.assertEqual(
+            len(callback.outputs), 4
+        )  # 8 samples / 2 batch_size = 4 batches
+
+        # Verify callback outputs match accumulated outputs when concatenated
+        concatenated_outputs_1 = np.concatenate(
+            [out[0] for out in callback.outputs], axis=0
+        )
+        concatenated_outputs_2 = np.concatenate(
+            [out[1] for out in callback.outputs], axis=0
+        )
+        self.assertAllClose(outputs_accumulated[0], concatenated_outputs_1)
+        self.assertAllClose(outputs_accumulated[1], concatenated_outputs_2)
+
     @pytest.mark.requires_trainable_backend
     def test_recompile(self):
         model = ExampleModel(units=3)