@@ -114,10 +114,16 @@ def predict(self, inputs, logits=False):
         # Apply defences
         inputs = self._apply_defences_predict(inputs)

-        preds = self._preds([inputs])[0]
-        if not logits:
-            exp = np.exp(preds - np.max(preds, axis=1, keepdims=True))
-            preds = exp / np.sum(exp, axis=1, keepdims=True)
+        # Run predictions with batching
+        batch_size = 512
+        preds = np.zeros((inputs.shape[0], self.nb_classes), dtype=np.float32)
+        for b in range(inputs.shape[0] // batch_size + 1):
+            begin, end = b * batch_size, min((b + 1) * batch_size, inputs.shape[0])
+            preds[begin:end] = self._preds([inputs[begin:end]])[0]
+
+            if not logits:
+                exp = np.exp(preds[begin:end] - np.max(preds[begin:end], axis=1, keepdims=True))
+                preds[begin:end] = exp / np.sum(exp, axis=1, keepdims=True)

         return preds

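For reference, the new predict path slices the input into fixed-size batches, runs the model on each slice, and, when probabilities rather than raw logits are requested, applies a numerically stable softmax to that slice. Below is a minimal standalone sketch of the same pattern, with a hypothetical model_logits callable standing in for self._preds; it is illustrative only, not the classifier's API.

import numpy as np

def batched_predict(model_logits, x, nb_classes, batch_size=512, logits=False):
    # Pre-allocate the full prediction matrix, then fill it one batch at a time.
    preds = np.zeros((x.shape[0], nb_classes), dtype=np.float32)
    for b in range(x.shape[0] // batch_size + 1):
        begin, end = b * batch_size, min((b + 1) * batch_size, x.shape[0])
        if begin == end:
            continue  # skip the empty trailing slice when batch_size divides x.shape[0]
        preds[begin:end] = model_logits(x[begin:end])
        if not logits:
            # Numerically stable softmax: subtract the row-wise max before exponentiating.
            exp = np.exp(preds[begin:end] - np.max(preds[begin:end], axis=1, keepdims=True))
            preds[begin:end] = exp / np.sum(exp, axis=1, keepdims=True)
    return preds
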
@@ -141,13 +147,13 @@ def fit(self, inputs, outputs, batch_size=128, nb_epochs=20):
         # Apply defences
         inputs, outputs = self._apply_defences_fit(inputs, outputs)

-        gen = generator(inputs, outputs, batch_size)
+        gen = generator_fit(inputs, outputs, batch_size)
         self._model.fit_generator(gen, steps_per_epoch=inputs.shape[0] / batch_size, epochs=nb_epochs)


-def generator(data, labels, batch_size=128):
+def generator_fit(data, labels, batch_size=128):
     """
-    Minimal data generator for batching large datasets.
+    Minimal data generator for randomly batching large datasets.

     :param data: The data sample to batch.
     :type data: `np.ndarray`
@@ -160,4 +166,4 @@ def generator(data, labels, batch_size=128):
160166 """
161167 while True :
162168 indices = np .random .randint (data .shape [0 ], size = batch_size )
163- yield data [indices ], labels [indices ]
169+ yield data [indices ], labels [indices ]
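The renamed generator_fit draws batch_size indices uniformly at random (with replacement) on every iteration and never terminates, which is what Keras's fit_generator expects. A rough, self-contained usage sketch follows; the toy arrays and the model name in the comment are assumptions made for illustration.

import numpy as np

def generator_fit(data, labels, batch_size=128):
    # Infinite generator: each batch is an independent random sample of the dataset.
    while True:
        indices = np.random.randint(data.shape[0], size=batch_size)
        yield data[indices], labels[indices]

# Toy data purely for illustration.
x_train = np.random.rand(1000, 28, 28, 1).astype(np.float32)
y_train = np.eye(10)[np.random.randint(10, size=1000)]

gen = generator_fit(x_train, y_train, batch_size=128)
x_batch, y_batch = next(gen)   # x_batch: (128, 28, 28, 1), y_batch: (128, 10)

# With a compiled Keras model `model`, the fit() call above boils down to:
# model.fit_generator(gen, steps_per_epoch=x_train.shape[0] / 128, epochs=20)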