Skip to content

Commit c3a7ac5

Browse files
Irina NicolaeIrina Nicolae
authored andcommitted
Add batching for Keras prediction
1 parent 415e819 commit c3a7ac5

File tree

1 file changed

+14
-8
lines changed

1 file changed

+14
-8
lines changed

art/classifiers/keras.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -114,10 +114,16 @@ def predict(self, inputs, logits=False):
114114
# Apply defences
115115
inputs = self._apply_defences_predict(inputs)
116116

117-
preds = self._preds([inputs])[0]
118-
if not logits:
119-
exp = np.exp(preds - np.max(preds, axis=1, keepdims=True))
120-
preds = exp / np.sum(exp, axis=1, keepdims=True)
117+
# Run predictions with batching
118+
batch_size = 512
119+
preds = np.zeros((inputs.shape[0], self.nb_classes), dtype=np.float32)
120+
for b in range(inputs.shape[0] // batch_size + 1):
121+
begin, end = b * batch_size, min((b + 1) * batch_size, inputs.shape[0])
122+
preds[begin:end] = self._preds([inputs[begin:end]])[0]
123+
124+
if not logits:
125+
exp = np.exp(preds[begin:end] - np.max(preds[begin:end], axis=1, keepdims=True))
126+
preds[begin:end] = exp / np.sum(exp, axis=1, keepdims=True)
121127

122128
return preds
123129

@@ -141,13 +147,13 @@ def fit(self, inputs, outputs, batch_size=128, nb_epochs=20):
141147
# Apply defences
142148
inputs, outputs = self._apply_defences_fit(inputs, outputs)
143149

144-
gen = generator(inputs, outputs, batch_size)
150+
gen = generator_fit(inputs, outputs, batch_size)
145151
self._model.fit_generator(gen, steps_per_epoch=inputs.shape[0] / batch_size, epochs=nb_epochs)
146152

147153

148-
def generator(data, labels, batch_size=128):
154+
def generator_fit(data, labels, batch_size=128):
149155
"""
150-
Minimal data generator for batching large datasets.
156+
Minimal data generator for randomly batching large datasets.
151157
152158
:param data: The data sample to batch.
153159
:type data: `np.ndarray`
@@ -160,4 +166,4 @@ def generator(data, labels, batch_size=128):
160166
"""
161167
while True:
162168
indices = np.random.randint(data.shape[0], size=batch_size)
163-
yield data[indices], labels[indices]
169+
yield data[indices], labels[indices]

0 commit comments

Comments
 (0)