@@ -259,17 +259,31 @@ def encoding_length(self) -> int:
259259 def input_shape (self ) -> Tuple [int , ...]:
260260 raise NotImplementedError
261261
262- def predict (self , x : np .ndarray , batch_size : int = 128 , ** kwargs ) -> np .ndarray :
262+ def predict (self , x : np .ndarray , batch_size : int = 128 , training_mode : bool = False , ** kwargs ) -> np .ndarray :
263263 """
264264 Perform projections over a batch of encodings.
265265
266266 :param x: Encodings.
267267 :param batch_size: Batch size.
268+ :param training_mode: `True` for model set to training mode and `'False` for model set to evaluation mode.
268269 :return: Array of prediction projections of shape `(num_inputs, nb_classes)`.
269270 """
270- logging .info ("Projecting new sample from z value" )
271- y = self ._model (x ).numpy ()
272- return y
271+ # Run prediction with batch processing
272+ results_list = []
273+ num_batch = int (np .ceil (len (x ) / float (batch_size )))
274+ for m in range (num_batch ):
275+ # Batch indexes
276+ begin , end = (
277+ m * batch_size ,
278+ min ((m + 1 ) * batch_size , x .shape [0 ]),
279+ )
280+
281+ # Run prediction
282+ results_list .append (self ._model (x [begin :end ], training = training_mode ).numpy ())
283+
284+ results = np .vstack (results_list )
285+
286+ return results
273287
274288 def loss_gradient (self , x , y , ** kwargs ) -> np .ndarray :
275289 raise NotImplementedError
0 commit comments