2020"""
2121from __future__ import absolute_import , division , print_function , unicode_literals
2222
23+ from functools import total_ordering
2324import logging
2425from typing import Callable , List , Optional , Union , Tuple , TYPE_CHECKING
2526
@@ -41,22 +42,25 @@ class BlackBoxClassifier(ClassifierMixin, BaseEstimator):
4142 Class for black-box classifiers.
4243 """
4344
44- estimator_params = Classifier .estimator_params + ["nb_classes" , "input_shape" , "predict " ]
45+ estimator_params = Classifier .estimator_params + ["nb_classes" , "input_shape" , "predict_fn " ]
4546
4647 def __init__ (
4748 self ,
48- predict_fn : Callable ,
49+ predict_fn : Union [ Callable , Tuple [ np . ndarray , np . ndarray ]] ,
4950 input_shape : Tuple [int , ...],
5051 nb_classes : int ,
5152 clip_values : Optional ["CLIP_VALUES_TYPE" ] = None ,
5253 preprocessing_defences : Union ["Preprocessor" , List ["Preprocessor" ], None ] = None ,
5354 postprocessing_defences : Union ["Postprocessor" , List ["Postprocessor" ], None ] = None ,
5455 preprocessing : "PREPROCESSING_TYPE" = (0.0 , 1.0 ),
56+ fuzzy_float_compare : bool = False ,
5557 ):
5658 """
5759 Create a `Classifier` instance for a black-box model.
5860
59- :param predict_fn: Function that takes in one input of the data and returns the one-hot encoded predicted class.
61+ :param predict_fn: Function that takes in an `np.ndarray` of input data and returns the one-hot encoded matrix
62+ of predicted classes or tuple of the form `(inputs, labels)` containing the predicted labels for each
63+ input.
6064 :param input_shape: Size of input.
6165 :param nb_classes: Number of prediction classes.
6266 :param clip_values: Tuple of the form `(min, max)` of floats or `np.ndarray` representing the minimum and
@@ -68,6 +72,9 @@ def __init__(
6872 :param preprocessing: Tuple of the form `(subtrahend, divisor)` of floats or `np.ndarray` of values to be
6973 used for data preprocessing. The first value will be subtracted from the input. The input will then
7074 be divided by the second one.
75+ :param fuzzy_float_compare: If `predict_fn` is a tuple mapping inputs to labels, and this is True, looking up
76+ inputs in the table will be done using `numpy.isclose`. Only set to True if really needed, since this
77+ severely affects performance.
7178 """
7279 super ().__init__ (
7380 model = None ,
@@ -76,8 +83,10 @@ def __init__(
7683 postprocessing_defences = postprocessing_defences ,
7784 preprocessing = preprocessing ,
7885 )
79-
80- self ._predict_fn = predict_fn
86+ if callable (predict_fn ):
87+ self ._predict_fn = predict_fn
88+ else :
89+ self ._predict_fn = _make_lookup_predict_fn (predict_fn , fuzzy_float_compare )
8190 self ._input_shape = input_shape
8291 self ._nb_classes = nb_classes
8392
@@ -161,24 +170,27 @@ class BlackBoxClassifierNeuralNetwork(NeuralNetworkMixin, ClassifierMixin, BaseE
161170 NeuralNetworkMixin .estimator_params
162171 + ClassifierMixin .estimator_params
163172 + BaseEstimator .estimator_params
164- + ["nb_classes" , "input_shape" , "predict " ]
173+ + ["nb_classes" , "input_shape" , "predict_fn " ]
165174 )
166175
167176 def __init__ (
168177 self ,
169- predict : Callable ,
178+ predict_fn : Union [ Callable , Tuple [ np . ndarray , np . ndarray ]] ,
170179 input_shape : Tuple [int , ...],
171180 nb_classes : int ,
172181 channels_first : bool = True ,
173182 clip_values : Optional ["CLIP_VALUES_TYPE" ] = None ,
174183 preprocessing_defences : Union ["Preprocessor" , List ["Preprocessor" ], None ] = None ,
175184 postprocessing_defences : Union ["Postprocessor" , List ["Postprocessor" ], None ] = None ,
176185 preprocessing : "PREPROCESSING_TYPE" = (0 , 1 ),
186+ fuzzy_float_compare : bool = False ,
177187 ):
178188 """
179189 Create a `Classifier` instance for a black-box model.
180190
181- :param predict: Function that takes in one input of the data and returns the one-hot encoded predicted class.
191+ :param predict_fn: Function that takes in an `np.ndarray` of input data and returns the one-hot encoded matrix
192+ of predicted classes or tuple of the form `(inputs, labels)` containing the predicted labels for each
193+ input.
182194 :param input_shape: Size of input.
183195 :param nb_classes: Number of prediction classes.
184196 :param channels_first: Set channels first or last.
@@ -191,6 +203,9 @@ def __init__(
191203 :param preprocessing: Tuple of the form `(subtrahend, divisor)` of floats or `np.ndarray` of values to be
192204 used for data preprocessing. The first value will be subtracted from the input. The input will then
193205 be divided by the second one.
206+ :param fuzzy_float_compare: If `predict_fn` is a tuple mapping inputs to labels, and this is True, looking up
207+ inputs in the table will be done using `numpy.isclose`. Only set to True if really needed, since this
208+ severely affects performance.
194209 """
195210 super ().__init__ (
196211 model = None ,
@@ -201,7 +216,10 @@ def __init__(
201216 preprocessing = preprocessing ,
202217 )
203218
204- self ._predictions = predict
219+ if callable (predict_fn ):
220+ self ._predict_fn = predict_fn
221+ else :
222+ self ._predict_fn = _make_lookup_predict_fn (predict_fn , fuzzy_float_compare )
205223 self ._input_shape = input_shape
206224 self ._nb_classes = nb_classes
207225 self ._learning_phase = None
@@ -236,7 +254,7 @@ def predict(self, x: np.ndarray, batch_size: int = 128, **kwargs):
236254 batch_index * batch_size ,
237255 min ((batch_index + 1 ) * batch_size , x_preprocessed .shape [0 ]),
238256 )
239- predictions [begin :end ] = self ._predictions (x_preprocessed [begin :end ])
257+ predictions [begin :end ] = self ._predict_fn (x_preprocessed [begin :end ])
240258
241259 # Apply postprocessing
242260 predictions = self ._apply_postprocessing (preds = predictions , fit = False )
@@ -287,3 +305,89 @@ def loss(self, x: np.ndarray, y: np.ndarray, **kwargs) -> np.ndarray:
287305
288306 def compute_loss (self , x : np .ndarray , y : np .ndarray , ** kwargs ) -> np .ndarray :
289307 raise NotImplementedError
308+
309+
310+ @total_ordering
311+ class FuzzyMapping :
312+ """
313+ Class for a sample/label pair to be used in a `SortedList`.
314+ """
315+
316+ def __init__ (self , key : np .ndarray , value = None ):
317+ """
318+ Create an instance of a key/value to pair to be used in a `SortedList`.
319+
320+ :param key: The sample to be matched against.
321+ :param value: The mapped value.
322+ """
323+ self .key = key
324+ self .value = value
325+
326+ def __eq__ (self , other ):
327+ return np .all (np .isclose (self .key , other .key ))
328+
329+ def __ge__ (self , other ):
330+ # This implements >= comparison so we can use this class in a `SortedList`. The `total_ordering` decorator
331+ # automatically generates the rest of the comparison magic functions based on this one
332+
333+ close_cells = np .isclose (self .key , other .key )
334+ if np .all (close_cells ):
335+ return True
336+
337+ # If the keys are not exactly the same (up to floating-point inaccuracies), we compare the value of the first
338+ # index which is not the same to decide on an ordering
339+
340+ compare_idx = np .unravel_index (np .argmin (close_cells ), shape = self .key .shape )
341+ return self .key [compare_idx ] >= other .key [compare_idx ]
342+
343+
344+ def _make_lookup_predict_fn (existing_predictions : Tuple [np .ndarray , np .ndarray ], fuzzy_float_compare : bool ) -> Callable :
345+ """
346+ Makes a predict_fn callback based on a table of existing predictions.
347+
348+ :param existing_predictions: Tuple of (samples, labels).
349+ :param fuzzy_float_compare: Look up predictions using `np.isclose`, only set to True if really needed, since this
350+ severely affects performance.
351+ :return: Prediction function.
352+ """
353+
354+ samples , labels = existing_predictions
355+
356+ if fuzzy_float_compare :
357+ from sortedcontainers import SortedList
358+
359+ # Construct a search-tree of the predictions, using fuzzy float comparison
360+ sorted_predictions = SortedList ([FuzzyMapping (key , value ) for key , value in zip (samples , labels )])
361+
362+ def fuzzy_predict_fn (batch ):
363+ predictions = []
364+ for row in batch :
365+ try :
366+ match_idx = sorted_predictions .index (FuzzyMapping (row ))
367+ except ValueError as err :
368+ raise ValueError ("No existing prediction for queried input" ) from err
369+
370+ predictions .append (sorted_predictions [match_idx ].value )
371+
372+ return np .array (predictions )
373+
374+ return fuzzy_predict_fn
375+
376+ # Construct a dictionary to map from samples to predictions. We use the bytes of the `ndarray` as the key,
377+ # because the `ndarray` itself is not hashable
378+ mapping = dict ()
379+ for x , y in zip (samples , labels ):
380+ mapping [x .tobytes ()] = y
381+
382+ def predict_fn (batch ):
383+ predictions = []
384+ for row in batch :
385+ row_bytes = row .tobytes ()
386+ if row .tobytes () not in mapping :
387+ raise ValueError ("No existing prediction for queried input" )
388+
389+ predictions .append (mapping [row_bytes ])
390+
391+ return np .array (predictions )
392+
393+ return predict_fn
0 commit comments