diff --git a/ivtmetrics/recognition.py b/ivtmetrics/recognition.py index c785bf1..9320543 100755 --- a/ivtmetrics/recognition.py +++ b/ivtmetrics/recognition.py @@ -60,10 +60,7 @@ def __init__(self, num_class=100, ignore_null=False): self.reset_global() def resolve_nan(self, classwise): - equiv_nan = ['-0', '-0.', '-0.0', '-.0'] - classwise = list(map(str, classwise)) - classwise = [np.nan if x in equiv_nan else x for x in classwise] - classwise = np.array(list(map(float, classwise))) + classwise[classwise==-0.0] = np.nan return classwise ##%%%%%%%%%%%%%%%%%%% RESET OP #%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% @@ -264,4 +261,4 @@ def aggregate_global_records_partial(self): if len(self.targets) > 0: global_targets.append(self.targets) global_predictions.append(self.predictions) - return global_targets, global_predictions \ No newline at end of file + return global_targets, global_predictions