keras/src/backend/jax/trainer.py (+8 −3)

@@ -638,7 +638,7 @@ def evaluate(

     @traceback_utils.filter_traceback
     def predict(
-        self, x, batch_size=None, verbose="auto", steps=None, callbacks=None
+        self, x, batch_size=None, verbose="auto", steps=None, callbacks=None, accumulate=True
     ):
         # Create an iterator that yields batches of input data.
         epoch_iterator = JAXEpochIterator(

@@ -718,7 +718,8 @@ def append_to_outputs(batch_outputs, outputs):
                 # during predict(), but it's allowed.
                 "non_trainable_variables": non_trainable_variables,
             }
-            outputs = append_to_outputs(batch_outputs, outputs)
+            if accumulate:
+                outputs = append_to_outputs(batch_outputs, outputs)

             # Dispatch callbacks. This takes care of async dispatch.
             callbacks.on_predict_batch_end(

@@ -731,7 +732,11 @@ def append_to_outputs(batch_outputs, outputs):
         self.jax_state_sync()
         callbacks.on_predict_end()
         self._jax_state = None
-        return tree.map_structure_up_to(batch_outputs, np.concatenate, outputs)
+        if accumulate:
+            if outputs is None:
+                return None
+            return tree.map_structure_up_to(batch_outputs, np.concatenate, outputs)
+        return outputs

     def train_on_batch(
         self,
keras/src/backend/numpy/trainer.py (+8 −3)

@@ -170,7 +170,7 @@ def fit(

     @traceback_utils.filter_traceback
     def predict(
-        self, x, batch_size=None, verbose="auto", steps=None, callbacks=None
+        self, x, batch_size=None, verbose="auto", steps=None, callbacks=None, accumulate=True
     ):
         # Create an iterator that yields batches of input data.
         epoch_iterator = EpochIterator(

@@ -214,12 +214,17 @@ def append_to_outputs(batch_outputs, outputs):
         for begin_step, end_step, data in epoch_iterator:
             callbacks.on_predict_batch_begin(begin_step)
             batch_outputs = self.predict_function(data)
-            outputs = append_to_outputs(batch_outputs, outputs)
+            if accumulate:
+                outputs = append_to_outputs(batch_outputs, outputs)
             callbacks.on_predict_batch_end(end_step, {"outputs": batch_outputs})
             if self.stop_predicting:
                 break
         callbacks.on_predict_end()
-        return tree.map_structure_up_to(batch_outputs, np.concatenate, outputs)
+        if accumulate:
+            if outputs is None:
+                return None
+            return tree.map_structure_up_to(batch_outputs, np.concatenate, outputs)
+        return outputs

     @traceback_utils.filter_traceback
     def evaluate(
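For readers unfamiliar with the accumulation path shown above: `append_to_outputs` builds one list of per-batch arrays per output leaf, and the final `tree.map_structure_up_to`/`np.concatenate` step stitches those lists back into a single structure. A minimal standalone sketch of that behavior (the two-headed dict outputs and their shapes are made up for illustration):

```python
import numpy as np
from keras import tree

# Pretend the model has two named outputs and predict() ran two batches.
batch1 = {"a": np.ones((2, 4)), "b": np.zeros((2, 3))}
batch2 = {"a": np.ones((3, 4)), "b": np.zeros((3, 3))}

# What append_to_outputs accumulates: one list of per-batch arrays per leaf.
outputs = {"a": [batch1["a"], batch2["a"]], "b": [batch1["b"], batch2["b"]]}

# The accumulate=True return path: concatenate each list along axis 0,
# using one batch's outputs as the structural template.
result = tree.map_structure_up_to(batch1, np.concatenate, outputs)
assert result["a"].shape == (5, 4) and result["b"].shape == (5, 3)
```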
keras/src/backend/openvino/trainer.py (+14 −7)

@@ -171,7 +171,7 @@ def fit(

     @traceback_utils.filter_traceback
     def predict(
-        self, x, batch_size=None, verbose="auto", steps=None, callbacks=None
+        self, x, batch_size=None, verbose="auto", steps=None, callbacks=None, accumulate=True
     ):
         # Create an iterator that yields batches of input data.
         epoch_iterator = EpochIterator(

@@ -213,15 +213,22 @@ def append_to_outputs(batch_outputs, outputs):
         self.stop_predicting = False
         callbacks.on_predict_begin()
         outputs = None
-        for begin_step, end_step, data in epoch_iterator.enumerate_epoch():
-            callbacks.on_predict_batch_begin(begin_step)
-            batch_outputs = self.predict_function(data)
-            outputs = append_to_outputs(batch_outputs, outputs)
-            callbacks.on_predict_batch_end(end_step, {"outputs": batch_outputs})
+        for _, iterator in epoch_iterator:
+            callbacks.on_predict_batch_begin(epoch_iterator.current_step)
+            batch_outputs = self.predict_function(iterator)
+            if accumulate:
+                outputs = append_to_outputs(batch_outputs, outputs)
+            callbacks.on_predict_batch_end(
+                epoch_iterator.current_step, {"outputs": batch_outputs}
+            )
             if self.stop_predicting:
                 break
         callbacks.on_predict_end()
-        return tree.map_structure_up_to(batch_outputs, np.concatenate, outputs)
+        if accumulate:
+            if outputs is None:
+                return None
+            return tree.map_structure_up_to(batch_outputs, np.concatenate, outputs)
+        return outputs

     @traceback_utils.filter_traceback
     def evaluate(
keras/src/backend/tensorflow/trainer.py (+11 −6)

@@ -521,7 +521,7 @@ def evaluate(

     @traceback_utils.filter_traceback
     def predict(
-        self, x, batch_size=None, verbose="auto", steps=None, callbacks=None
+        self, x, batch_size=None, verbose="auto", steps=None, callbacks=None, accumulate=True
     ):
         # Create an iterator that yields batches of input data.
         epoch_iterator = TFEpochIterator(

@@ -586,17 +586,22 @@ def get_data(iterator):
             callbacks.on_predict_batch_begin(begin_step)
             data = get_data(iterator)
             batch_outputs = self.predict_function(data)
-            outputs = append_to_outputs(batch_outputs, outputs)
+            if accumulate:
+                outputs = append_to_outputs(batch_outputs, outputs)
             callbacks.on_predict_batch_end(
                 end_step, {"outputs": batch_outputs}
             )
             if self.stop_predicting:
                 break
         callbacks.on_predict_end()
-        outputs = tree.map_structure_up_to(
-            batch_outputs, potentially_ragged_concat, outputs
-        )
-        return tree.map_structure(convert_to_np_if_not_ragged, outputs)
+        if accumulate:
+            if outputs is None:
+                return None
+            outputs = tree.map_structure_up_to(
+                batch_outputs, potentially_ragged_concat, outputs
+            )
+            return tree.map_structure(convert_to_np_if_not_ragged, outputs)
+        return outputs

     def train_on_batch(
         self,
keras/src/backend/torch/trainer.py (+9 −4)

@@ -395,7 +395,7 @@ def evaluate(

     @traceback_utils.filter_traceback
     def predict(
-        self, x, batch_size=None, verbose="auto", steps=None, callbacks=None
+        self, x, batch_size=None, verbose="auto", steps=None, callbacks=None, accumulate=True
     ):
         # Create an iterator that yields batches of input data.
         epoch_iterator = TorchEpochIterator(

@@ -442,13 +442,18 @@ def append_to_outputs(batch_outputs, outputs):
         for begin_step, end_step, data in epoch_iterator:
             callbacks.on_predict_batch_begin(begin_step)
             batch_outputs = self.predict_function(data)
-            outputs = append_to_outputs(batch_outputs, outputs)
+            if accumulate:
+                outputs = append_to_outputs(batch_outputs, outputs)
             callbacks.on_predict_batch_end(end_step, {"outputs": batch_outputs})
             if self.stop_predicting:
                 break
         callbacks.on_predict_end()
-        outputs = tree.map_structure(backend.convert_to_numpy, outputs)
-        return tree.map_structure_up_to(batch_outputs, np.concatenate, outputs)
+        if accumulate:
+            if outputs is None:
+                return None
+            outputs = tree.map_structure(backend.convert_to_numpy, outputs)
+            return tree.map_structure_up_to(batch_outputs, np.concatenate, outputs)
+        return outputs

     def train_on_batch(
         self,
keras/src/trainers/trainer.py (+6 −2)

@@ -807,7 +807,7 @@ def evaluate(
         raise NotImplementedError

     def predict(
-        self, x, batch_size=None, verbose="auto", steps=None, callbacks=None
+        self, x, batch_size=None, verbose="auto", steps=None, callbacks=None, accumulate=True
     ):
         """Generates output predictions for the input samples.

@@ -858,9 +858,13 @@ def predict(
                 repeating dataset, it will run indefinitely.
             callbacks: List of `keras.callbacks.Callback` instances.
                 List of callbacks to apply during prediction.
+            accumulate: Boolean. Whether to accumulate predictions in memory.
+                If `False`, predictions are not returned and must be handled
+                via callbacks to avoid memory issues with large datasets.
+                Defaults to `True`.

         Returns:
-            NumPy array(s) of predictions.
+            NumPy array(s) of predictions if `accumulate=True`, otherwise `None`.
         """
         raise NotImplementedError
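Since `accumulate=False` makes `predict()` return `None`, callers consume predictions through callbacks, one batch at a time. A minimal sketch of that pattern against this PR's API (the `PredictionWriter` class and file layout are hypothetical, and it assumes a single-output model whose per-batch outputs convert cleanly to NumPy):

```python
import numpy as np
import keras


class PredictionWriter(keras.callbacks.Callback):
    """Streams each batch of predictions to disk instead of holding them in RAM."""

    def __init__(self, path_prefix):
        super().__init__()
        self.path_prefix = path_prefix
        self.batch_index = 0

    def on_predict_batch_end(self, batch, logs=None):
        # Each trainer in this PR passes {"outputs": batch_outputs} as logs.
        batch_outputs = logs["outputs"]
        np.save(
            f"{self.path_prefix}_{self.batch_index}.npy",
            np.asarray(batch_outputs),
        )
        self.batch_index += 1


# With accumulate=False, the trainers skip append_to_outputs entirely, so
# peak memory stays at roughly one batch of outputs:
# preds = model.predict(x, accumulate=False, callbacks=[PredictionWriter("preds")])
# assert preds is None
```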