diff --git a/keras/src/backend/jax/trainer.py b/keras/src/backend/jax/trainer.py index 5f01505c2d4..a68ef15444c 100644 --- a/keras/src/backend/jax/trainer.py +++ b/keras/src/backend/jax/trainer.py @@ -638,7 +638,13 @@ def evaluate( @traceback_utils.filter_traceback def predict( - self, x, batch_size=None, verbose="auto", steps=None, callbacks=None + self, + x, + batch_size=None, + verbose="auto", + steps=None, + callbacks=None, + accumulate=True, ): # Create an iterator that yields batches of input data. epoch_iterator = JAXEpochIterator( @@ -718,7 +724,8 @@ def append_to_outputs(batch_outputs, outputs): # during predict(), but it's allowed. "non_trainable_variables": non_trainable_variables, } - outputs = append_to_outputs(batch_outputs, outputs) + if accumulate: + outputs = append_to_outputs(batch_outputs, outputs) # Dispatch callbacks. This takes care of async dispatch. callbacks.on_predict_batch_end( @@ -731,7 +738,13 @@ def append_to_outputs(batch_outputs, outputs): self.jax_state_sync() callbacks.on_predict_end() self._jax_state = None - return tree.map_structure_up_to(batch_outputs, np.concatenate, outputs) + if accumulate: + if outputs is None: + return None + return tree.map_structure_up_to( + batch_outputs, np.concatenate, outputs + ) + return outputs def train_on_batch( self, diff --git a/keras/src/backend/numpy/trainer.py b/keras/src/backend/numpy/trainer.py index fd8c276a86d..852a1d35e60 100644 --- a/keras/src/backend/numpy/trainer.py +++ b/keras/src/backend/numpy/trainer.py @@ -170,7 +170,13 @@ def fit( @traceback_utils.filter_traceback def predict( - self, x, batch_size=None, verbose="auto", steps=None, callbacks=None + self, + x, + batch_size=None, + verbose="auto", + steps=None, + callbacks=None, + accumulate=True, ): # Create an iterator that yields batches of input data. epoch_iterator = EpochIterator( @@ -214,12 +220,19 @@ def append_to_outputs(batch_outputs, outputs): for begin_step, end_step, data in epoch_iterator: callbacks.on_predict_batch_begin(begin_step) batch_outputs = self.predict_function(data) - outputs = append_to_outputs(batch_outputs, outputs) + if accumulate: + outputs = append_to_outputs(batch_outputs, outputs) callbacks.on_predict_batch_end(end_step, {"outputs": batch_outputs}) if self.stop_predicting: break callbacks.on_predict_end() - return tree.map_structure_up_to(batch_outputs, np.concatenate, outputs) + if accumulate: + if outputs is None: + return None + return tree.map_structure_up_to( + batch_outputs, np.concatenate, outputs + ) + return outputs @traceback_utils.filter_traceback def evaluate( diff --git a/keras/src/backend/openvino/trainer.py b/keras/src/backend/openvino/trainer.py index ac2e64a8060..c3d5c81d26a 100644 --- a/keras/src/backend/openvino/trainer.py +++ b/keras/src/backend/openvino/trainer.py @@ -171,7 +171,13 @@ def fit( @traceback_utils.filter_traceback def predict( - self, x, batch_size=None, verbose="auto", steps=None, callbacks=None + self, + x, + batch_size=None, + verbose="auto", + steps=None, + callbacks=None, + accumulate=True, ): # Create an iterator that yields batches of input data. epoch_iterator = EpochIterator( @@ -213,15 +219,22 @@ def append_to_outputs(batch_outputs, outputs): self.stop_predicting = False callbacks.on_predict_begin() outputs = None - for begin_step, end_step, data in epoch_iterator.enumerate_epoch(): + for begin_step, end_step, iterator in epoch_iterator: callbacks.on_predict_batch_begin(begin_step) - batch_outputs = self.predict_function(data) - outputs = append_to_outputs(batch_outputs, outputs) + batch_outputs = self.predict_function(iterator) + if accumulate: + outputs = append_to_outputs(batch_outputs, outputs) callbacks.on_predict_batch_end(end_step, {"outputs": batch_outputs}) if self.stop_predicting: break callbacks.on_predict_end() - return tree.map_structure_up_to(batch_outputs, np.concatenate, outputs) + if accumulate: + if outputs is None: + return None + return tree.map_structure_up_to( + batch_outputs, np.concatenate, outputs + ) + return outputs @traceback_utils.filter_traceback def evaluate( diff --git a/keras/src/backend/tensorflow/trainer.py b/keras/src/backend/tensorflow/trainer.py index cd6410999dd..c77bdac1a13 100644 --- a/keras/src/backend/tensorflow/trainer.py +++ b/keras/src/backend/tensorflow/trainer.py @@ -521,7 +521,13 @@ def evaluate( @traceback_utils.filter_traceback def predict( - self, x, batch_size=None, verbose="auto", steps=None, callbacks=None + self, + x, + batch_size=None, + verbose="auto", + steps=None, + callbacks=None, + accumulate=True, ): # Create an iterator that yields batches of input data. epoch_iterator = TFEpochIterator( @@ -586,17 +592,22 @@ def get_data(iterator): callbacks.on_predict_batch_begin(begin_step) data = get_data(iterator) batch_outputs = self.predict_function(data) - outputs = append_to_outputs(batch_outputs, outputs) + if accumulate: + outputs = append_to_outputs(batch_outputs, outputs) callbacks.on_predict_batch_end( end_step, {"outputs": batch_outputs} ) if self.stop_predicting: break callbacks.on_predict_end() - outputs = tree.map_structure_up_to( - batch_outputs, potentially_ragged_concat, outputs - ) - return tree.map_structure(convert_to_np_if_not_ragged, outputs) + if accumulate: + if outputs is None: + return None + outputs = tree.map_structure_up_to( + batch_outputs, potentially_ragged_concat, outputs + ) + return tree.map_structure(convert_to_np_if_not_ragged, outputs) + return outputs def train_on_batch( self, diff --git a/keras/src/backend/torch/trainer.py b/keras/src/backend/torch/trainer.py index ad68c2f3a7e..71b1a33c07c 100644 --- a/keras/src/backend/torch/trainer.py +++ b/keras/src/backend/torch/trainer.py @@ -395,7 +395,13 @@ def evaluate( @traceback_utils.filter_traceback def predict( - self, x, batch_size=None, verbose="auto", steps=None, callbacks=None + self, + x, + batch_size=None, + verbose="auto", + steps=None, + callbacks=None, + accumulate=True, ): # Create an iterator that yields batches of input data. epoch_iterator = TorchEpochIterator( @@ -442,13 +448,20 @@ def append_to_outputs(batch_outputs, outputs): for begin_step, end_step, data in epoch_iterator: callbacks.on_predict_batch_begin(begin_step) batch_outputs = self.predict_function(data) - outputs = append_to_outputs(batch_outputs, outputs) + if accumulate: + outputs = append_to_outputs(batch_outputs, outputs) callbacks.on_predict_batch_end(end_step, {"outputs": batch_outputs}) if self.stop_predicting: break callbacks.on_predict_end() - outputs = tree.map_structure(backend.convert_to_numpy, outputs) - return tree.map_structure_up_to(batch_outputs, np.concatenate, outputs) + if accumulate: + if outputs is None: + return None + outputs = tree.map_structure(backend.convert_to_numpy, outputs) + return tree.map_structure_up_to( + batch_outputs, np.concatenate, outputs + ) + return outputs def train_on_batch( self, diff --git a/keras/src/trainers/trainer.py b/keras/src/trainers/trainer.py index bac422db249..8b5178891f0 100644 --- a/keras/src/trainers/trainer.py +++ b/keras/src/trainers/trainer.py @@ -807,7 +807,13 @@ def evaluate( raise NotImplementedError def predict( - self, x, batch_size=None, verbose="auto", steps=None, callbacks=None + self, + x, + batch_size=None, + verbose="auto", + steps=None, + callbacks=None, + accumulate=True, ): """Generates output predictions for the input samples. @@ -858,9 +864,14 @@ def predict( repeating dataset, it will run indefinitely. callbacks: List of `keras.callbacks.Callback` instances. List of callbacks to apply during prediction. + accumulate: Boolean. Whether to accumulate predictions in memory. + If `False`, predictions are not returned and must be handled + via callbacks to avoid memory issues with large datasets. + Defaults to `True`. Returns: - NumPy array(s) of predictions. + NumPy array(s) of predictions if `accumulate=True`, + otherwise `None`. """ raise NotImplementedError diff --git a/keras/src/trainers/trainer_test.py b/keras/src/trainers/trainer_test.py index 05e910aa603..d0014e59d38 100644 --- a/keras/src/trainers/trainer_test.py +++ b/keras/src/trainers/trainer_test.py @@ -2207,6 +2207,93 @@ def test_predict_dropout(self): out3 = model.predict_on_batch(np.ones((2, 20))) self.assertGreater(5, np.sum(np.abs(out2 - out3))) + def test_predict_accumulate_parameter(self): + # Test that `predict` with accumulate=True/False works correctly + model = ExampleModel(units=3) + x = np.ones((10, 4)) + + # Test accumulate=True (default behavior) + outputs_accumulated = model.predict(x, batch_size=2, accumulate=True) + self.assertIsInstance(outputs_accumulated, np.ndarray) + self.assertEqual(outputs_accumulated.shape, (10, 3)) + self.assertAllClose(outputs_accumulated, 4 * np.ones((10, 3))) + + # Test accumulate=False with callback to capture outputs + class OutputCaptureCallback(Callback): + def __init__(self): + super().__init__() + self.outputs = [] + + def on_predict_batch_end(self, batch, logs=None): + if logs and "outputs" in logs: + self.outputs.append(logs["outputs"]) + + callback = OutputCaptureCallback() + outputs_none = model.predict( + x, batch_size=2, accumulate=False, callbacks=[callback] + ) + + # Verify accumulate=False returns None + self.assertIsNone(outputs_none) + + # Verify callback captured the correct number of batches + self.assertEqual( + len(callback.outputs), 5 + ) # 10 samples / 2 batch_size = 5 batches + + # Verify callback outputs match accumulated outputs when concatenated + concatenated_outputs = np.concatenate(callback.outputs, axis=0) + self.assertAllClose(outputs_accumulated, concatenated_outputs) + + def test_predict_accumulate_parameter_multi_output(self): + # Test accumulate parameter with multi-output model + inputs = layers.Input((4,)) + output1 = layers.Dense(3, name="out1")(inputs) + output2 = layers.Dense(2, name="out2")(inputs) + model = models.Model(inputs=inputs, outputs=[output1, output2]) + + x = np.ones((8, 4)) + + # Test accumulate=True (default behavior) + outputs_accumulated = model.predict(x, batch_size=2, accumulate=True) + self.assertIsInstance(outputs_accumulated, list) + self.assertEqual(len(outputs_accumulated), 2) + self.assertEqual(outputs_accumulated[0].shape, (8, 3)) + self.assertEqual(outputs_accumulated[1].shape, (8, 2)) + + # Test accumulate=False with callback + class OutputCaptureCallback(Callback): + def __init__(self): + super().__init__() + self.outputs = [] + + def on_predict_batch_end(self, batch, logs=None): + if logs and "outputs" in logs: + self.outputs.append(logs["outputs"]) + + callback = OutputCaptureCallback() + outputs_none = model.predict( + x, batch_size=2, accumulate=False, callbacks=[callback] + ) + + # Verify accumulate=False returns None + self.assertIsNone(outputs_none) + + # Verify callback captured the correct outputs + self.assertEqual( + len(callback.outputs), 4 + ) # 8 samples / 2 batch_size = 4 batches + + # Verify callback outputs match accumulated outputs when concatenated + concatenated_outputs_1 = np.concatenate( + [out[0] for out in callback.outputs], axis=0 + ) + concatenated_outputs_2 = np.concatenate( + [out[1] for out in callback.outputs], axis=0 + ) + self.assertAllClose(outputs_accumulated[0], concatenated_outputs_1) + self.assertAllClose(outputs_accumulated[1], concatenated_outputs_2) + @pytest.mark.requires_trainable_backend def test_recompile(self): model = ExampleModel(units=3)