2020"""
2121from __future__ import absolute_import , division , print_function , unicode_literals
2222
23+ from functools import total_ordering
2324import logging
2425from typing import Callable , List , Optional , Union , Tuple , TYPE_CHECKING
2526
26- from functools import total_ordering
27-
2827import numpy as np
2928
3029from art .estimators .estimator import BaseEstimator , NeuralNetworkMixin
@@ -43,7 +42,7 @@ class BlackBoxClassifier(ClassifierMixin, BaseEstimator):
4342 Wrapper class for black-box classifiers.
4443 """
4544
46- estimator_params = Classifier .estimator_params + ["nb_classes" , "input_shape" , "predict " ]
45+ estimator_params = Classifier .estimator_params + ["nb_classes" , "input_shape" , "predict_fn " ]
4746
4847 def __init__ (
4948 self ,
@@ -171,7 +170,7 @@ class BlackBoxClassifierNeuralNetwork(NeuralNetworkMixin, ClassifierMixin, BaseE
171170 NeuralNetworkMixin .estimator_params
172171 + ClassifierMixin .estimator_params
173172 + BaseEstimator .estimator_params
174- + ["nb_classes" , "input_shape" , "predict " ]
173+ + ["nb_classes" , "input_shape" , "predict_fn " ]
175174 )
176175
177176 def __init__ (
@@ -309,14 +308,14 @@ def compute_loss(self, x: np.ndarray, y: np.ndarray, **kwargs) -> np.ndarray:
309308
310309
311310@total_ordering
312- class FuzzyMappingWrapper :
311+ class FuzzyMapping :
313312 """
314- Wrapper class for a sample, label pair to be used in a `SortedList`.
313+ Class for a sample/ label pair to be used in a `SortedList`.
315314 """
316315
317316 def __init__ (self , key : np .ndarray , value = None ):
318317 """
319- Create a wrapper for the key, value to pair to be used in a `SortedList`.
318+ Create an instance of a key/ value to pair to be used in a `SortedList`.
320319
321320 :param key: The sample to be matched against.
322321 :param value: The mapped value.
@@ -327,20 +326,18 @@ def __init__(self, key: np.ndarray, value=None):
327326 def __eq__ (self , other ):
328327 return np .all (np .isclose (self .key , other .key ))
329328
330- def __ne__ (self , other ):
331- return not self .__eq__ (other )
332-
333329 def __ge__ (self , other ):
334330 # This implements >= comparison so we can use this class in a `SortedList`. The `total_ordering` decorator
335331 # automatically generates the rest of the comparison magic functions based on this one
336332
337- if self .__eq__ (other ):
333+ close_cells = np .isclose (self .key , other .key )
334+ if np .all (close_cells ):
338335 return True
339336
340337 # If the keys are not exactly the same (up to floating-point inaccuracies), we compare the value of the first
341338 # index which is not the same to decide on an ordering
342339
343- compare_idx = np .unravel_index (np .argmin (np . isclose ( self . key , other . key ) ), shape = self .key .shape )
340+ compare_idx = np .unravel_index (np .argmin (close_cells ), shape = self .key .shape )
344341 return self .key [compare_idx ] >= other .key [compare_idx ]
345342
346343
@@ -360,13 +357,13 @@ def _make_lookup_predict_fn(existing_predictions: Tuple[np.ndarray, np.ndarray],
360357 from sortedcontainers import SortedList
361358
362359 # Construct a search-tree of the predictions, using fuzzy float comparison
363- sorted_predictions = SortedList ([FuzzyMappingWrapper (key , value ) for key , value in zip (samples , labels )])
360+ sorted_predictions = SortedList ([FuzzyMapping (key , value ) for key , value in zip (samples , labels )])
364361
365362 def fuzzy_predict_fn (batch ):
366363 predictions = []
367364 for row in batch :
368365 try :
369- match_idx = sorted_predictions .index (FuzzyMappingWrapper (row ))
366+ match_idx = sorted_predictions .index (FuzzyMapping (row ))
370367 except ValueError as err :
371368 raise ValueError ("No existing prediction for queried input" ) from err
372369
0 commit comments