diff --git a/tidecv/ap.py b/tidecv/ap.py index 0b34931f..97c99516 100644 --- a/tidecv/ap.py +++ b/tidecv/ap.py @@ -149,6 +149,9 @@ def get_mAP(self) -> float: aps = [x.get_ap() for x in self.objs.values() if not x.is_empty()] return sum(aps) / len(aps) + def get_per_class_APs(self) -> dict: + return {k : v.get_ap() for k, v in self.objs.items()} + def get_gt_positives(self) -> dict: return {k: v.num_gt_positives for k, v in self.objs.items()} diff --git a/tidecv/functions.py b/tidecv/functions.py index cc0ccd89..d17d7f91 100644 --- a/tidecv/functions.py +++ b/tidecv/functions.py @@ -87,8 +87,11 @@ def toRLE(mask:object, w:int, h:int): if type(mask) == list: # polygon -- a single object might consist of multiple parts # we merge all parts into one mask rle code - rles = maskUtils.frPyObjects(mask, h, w) - return maskUtils.merge(rles) + if mask: + rles = maskUtils.frPyObjects(mask, h, w) + return maskUtils.merge(rles) + else: + return mask elif type(mask['counts']) == list: # uncompressed RLE return maskUtils.frPyObjects(mask, h, w) diff --git a/tidecv/quantify.py b/tidecv/quantify.py index 9aa074be..b1ebb1b4 100644 --- a/tidecv/quantify.py +++ b/tidecv/quantify.py @@ -138,6 +138,7 @@ def __init__(self, gt:Data, preds:Data, pos_thresh:float, bg_thresh:float, mode: self.preds = preds self.errors = [] + self.per_class_errors = [] self.error_dict = {_type: [] for _type in TIDE._error_types} self.ap_data = ClassedAPDataObject() self.qualifiers = {} @@ -177,6 +178,7 @@ def _run(self): error.disabled = False self.ap = self.ap_data.get_mAP() + self.per_classes_ap = self.ap_data.get_per_class_APs() # Now that we've stored the fixed errors, we can clear the gt info self._clear() @@ -330,6 +332,39 @@ def fix_errors(self, condition=lambda x: False, transform=None, false_neg_dict:d return new_ap_data + def fix_main_per_class_errors(self, progressive: bool = False, error_types: list = None, qual: Qualifier = None) -> dict: + ap_data = self.ap_data + last_per_class_ap = self.per_classes_ap + + if qual is None: + qual = Qualifier('', None) + + if error_types is None: + error_types = TIDE._error_types + + errors_per_class = {} + for error in error_types: + _ap_data = self.fix_errors(qual._make_error_func(error), + ap_data=ap_data, disable_errors=progressive) + + new_per_class_ap = _ap_data.get_per_class_APs() + # If an error is negative that means it's likely due to binning differences, so just + # Ignore the negative by setting it to 0. + errors_per_class[error] = {k: max(new_per_class_ap[k] - last_per_class_ap[k], 0) + for k in new_per_class_ap.keys() + } + + if progressive: + last_per_class_ap = new_per_class_ap + ap_data = _ap_data + + # TODO: progressive + if progressive: + for error in self.errors: + error.disabled = False + + return errors_per_class + def fix_main_errors(self, progressive:bool=False, error_types:list=None, qual:Qualifier=None) -> dict: ap_data = self.ap_data last_ap = self.ap @@ -341,7 +376,6 @@ def fix_main_errors(self, progressive:bool=False, error_types:list=None, qual:Qu error_types = TIDE._error_types errors = {} - for error in error_types: _ap_data = self.fix_errors(qual._make_error_func(error), ap_data=ap_data, disable_errors=progressive) @@ -350,7 +384,7 @@ def fix_main_errors(self, progressive:bool=False, error_types:list=None, qual:Qu # If an error is negative that means it's likely due to binning differences, so just # Ignore the negative by setting it to 0. errors[error] = max(new_ap - last_ap, 0) - + if progressive: last_ap = new_ap ap_data = _ap_data @@ -435,6 +469,7 @@ def __init__(self, pos_threshold:float=0.5, background_threshold:float=0.1, mode self.runs = {} self.run_thresholds = {} self.run_main_errors = {} + self.run_main_per_class_errors = {} self.run_special_errors = {} self.qualifiers = OrderedDict() @@ -492,6 +527,7 @@ def add_qualifiers(self, *quals): def summarize(self): """ Summarizes the mAP values and errors for all runs in this TIDE object. Results are printed to the console. """ main_errors = self.get_main_errors() + main_per_class_errors = self.get_main_per_class_errors() special_errors = self.get_special_errors() for run_name, run in self.runs.items(): @@ -552,6 +588,15 @@ def summarize(self): [' dAP'] + ['{:6.2f}'.format(main_errors[run_name][err.short_name]) for err in TIDE._error_types] ], title='Main Errors') + print() + # Print the per class errors + P.print_table( + [['class'] + ['Type'] + [err.short_name for err in TIDE._error_types]] + + + [[run.gt.classes[k]] + [' dAP'] + ['{:6.2f}'.format(main_per_class_errors[run_name][err.short_name][k]) + for err in TIDE._error_types] + for k in sorted(main_per_class_errors[run_name][TIDE._error_types[0].short_name].keys())] + , title='Main Per Class Errors') print() @@ -605,6 +650,20 @@ def get_main_errors(self): return errors + def get_main_per_class_errors(self): + errors = {} + + for run_name, run in self.runs.items(): + if run_name in self.run_main_per_class_errors: + errors[run_name] = self.run_main_per_class_errors[run_name] + else: + errors[run_name] = { + error.short_name: value + for error, value in run.fix_main_per_class_errors().items() + } + + return errors + def get_special_errors(self): errors = {} @@ -631,4 +690,25 @@ def get_all_errors(self): 'special': self.get_special_errors() } + def get_confusion_matrix(self): + confusion_matrix = {} + for run_name, run in self.runs.items(): + n_classes = len(run.gt.classes) + #row: predicted classes, col: actual classes + cm = np.zeros((n_classes, n_classes), dtype=np.int32) + for error in run.errors: + if isinstance(error, ClassError): + cm[error.pred['class']-1][error.gt['class']-1] += 1 + confusion_matrix[run_name] = cm + sorted_keys = sorted(run.gt.classes.keys()) + print() + + P.print_table([ + ['pred/gt'] + [run.gt.classes[k] for k in sorted_keys], + ] + [ + [run.gt.classes[k]] + [str(cnt) for cnt in cm[i]] for i, k in enumerate(sorted_keys) + ], title=f"{run_name} confusion matrix") + + return confusion_matrix +