Skip to content

Commit 0a28c83

Browse files
author
Beat Buesser
committed
Add batch processing to TensorFlowGenerator.predict
Signed-off-by: Beat Buesser <[email protected]>
1 parent 9f75c05 commit 0a28c83

File tree

2 files changed

+23
-8
lines changed

2 files changed

+23
-8
lines changed

art/estimators/gan/tensorflow.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,12 +62,13 @@ def __init__(
6262

6363
def predict(self, x: np.ndarray, batch_size: int = 128, **kwargs) -> np.ndarray:
6464
"""
65-
Generates a sample
65+
Generates a sample.
6666
67-
param x: a seed
68-
:return: the sample
67+
:param x: A input seed.
68+
:param batch_size: The batch size for predictions.
69+
:return: The generated sample.
6970
"""
70-
return self.generator.predict(x)
71+
return self.generator.predict(x, batch_size=batch_size, **kwargs)
7172

7273
@property
7374
def input_shape(self) -> Tuple[int, int]:

art/estimators/generation/tensorflow.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -259,17 +259,31 @@ def encoding_length(self) -> int:
259259
def input_shape(self) -> Tuple[int, ...]:
260260
raise NotImplementedError
261261

262-
def predict(self, x: np.ndarray, batch_size: int = 128, **kwargs) -> np.ndarray:
262+
def predict(self, x: np.ndarray, batch_size: int = 128, training_mode: bool = False, **kwargs) -> np.ndarray:
263263
"""
264264
Perform projections over a batch of encodings.
265265
266266
:param x: Encodings.
267267
:param batch_size: Batch size.
268+
:param training_mode: `True` for model set to training mode and `'False` for model set to evaluation mode.
268269
:return: Array of prediction projections of shape `(num_inputs, nb_classes)`.
269270
"""
270-
logging.info("Projecting new sample from z value")
271-
y = self._model(x).numpy()
272-
return y
271+
# Run prediction with batch processing
272+
results_list = []
273+
num_batch = int(np.ceil(len(x) / float(batch_size)))
274+
for m in range(num_batch):
275+
# Batch indexes
276+
begin, end = (
277+
m * batch_size,
278+
min((m + 1) * batch_size, x.shape[0]),
279+
)
280+
281+
# Run prediction
282+
results_list.append(self._model(x[begin:end], training=training_mode).numpy())
283+
284+
results = np.vstack(results_list)
285+
286+
return results
273287

274288
def loss_gradient(self, x, y, **kwargs) -> np.ndarray:
275289
raise NotImplementedError

0 commit comments

Comments
 (0)