Skip to content

Commit 23486b6

Browse files
Lydia Chan and facebook-github-bot
authored and committed
Increase limit on number of detections per image in {COCO,LVIS}Evaluator
Summary:

## Context
- The current limit on the number of detections per image (`K`) in LVIS is 300.
- Implementing AP_pool/AP_fixed requires removing this default limit on `K`.
- [Literature](https://arxiv.org/pdf/2102.01066.pdf) has shown that increasing `K` correlates with AP gains.

## This Diff
- Changed limit on number of detections per image (`K`) to be customizable for LVIS and COCO through `TEST.DETECTIONS_PER_IMAGE` in the config.
  - For COCO:
    - Maintain the default `max_dets_per_image` to be [1, 10, 100] as from [COCOEval](https://www.internalfb.com/code/fbsource/[88bb57c3054a]/fbcode/deeplearning/projects/cocoApi/PythonAPI/pycocotools/cocoeval.py?lines=28-29).
    - Allow users to input a custom integer for `TEST.DETECTIONS_PER_IMAGE` in the config, and use [1, 10, `TEST.DETECTIONS_PER_IMAGE`] for COCOEval.
  - For LVIS:
    - Maintain the default `max_dets_per_image` to be 300 as from [LVISEval](https://www.internalfb.com/code/fbsource/[f6b86d023721]/fbcode/deeplearning/projects/lvisApi/lvis/eval.py?lines=528-529).
    - Allow users to input a custom integer for `TEST.DETECTIONS_PER_IMAGE` in the config, and use this in LVISEval.
- Added `COCOevalMaxDets` for evaluating AP with the custom limit on number of detections per image (since default `COCOeval` uses 100 as limit on detections per image for evaluating AP).

## Inference Runs using this Diff
- Performed inference using `K = {300, 1000, 10000, 100000}`.
- Launched fblearner flows for object detector baseline models with N1055536 (LVIS) and N1055756 (COCO).
- Recorded [results of running inference](https://docs.google.com/spreadsheets/d/1rgdjN2KvxcYfKCkGUC4tMw0XQJ5oZL0dwjOIh84YRg8/edit?usp=sharing).

Reviewed By: ppwwyyxx

Differential Revision: D30077359

fbshipit-source-id: 372eb5e0d7c228fb77fe23bf80d53597ec66287b
1 parent 55c0078 commit 23486b6

File tree

2 files changed

+155
-6
lines changed

2 files changed

+155
-6
lines changed

detectron2/evaluation/coco_evaluation.py

Lines changed: 128 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ def __init__(
4747
distributed=True,
4848
output_dir=None,
4949
*,
50+
max_dets_per_image=None,
5051
use_fast_impl=True,
5152
kpt_oks_sigmas=(),
5253
):
@@ -71,6 +72,10 @@ def __init__(
7172
1. "instances_predictions.pth" a file that can be loaded with `torch.load` and
7273
contains all the results in the format they are produced by the model.
7374
2. "coco_instances_results.json" a json file in COCO's result format.
75+
max_dets_per_image (int): limit on the maximum number of detections per image.
76+
By default in COCO, this limit is to 100, but this can be customized
77+
to be greater, as is needed in evaluation metrics AP fixed and AP pool
78+
(see https://arxiv.org/pdf/2102.01066.pdf)
7479
use_fast_impl (bool): use a fast but **unofficial** implementation to compute AP.
7580
Although the results should be very close to the official implementation in COCO
7681
API, it is still recommended to compute results with the official API for use in
@@ -85,6 +90,17 @@ def __init__(
8590
self._output_dir = output_dir
8691
self._use_fast_impl = use_fast_impl
8792

93+
# COCOeval requires the limit on the number of detections per image (maxDets) to be a list
94+
# with at least 3 elements. The default maxDets in COCOeval is [1, 10, 100], in which the
95+
# 3rd element (100) is used as the limit on the number of detections per image when
96+
# evaluating AP. COCOEvaluator expects an integer for max_dets_per_image, so for COCOeval,
97+
# we reformat max_dets_per_image into [1, 10, max_dets_per_image], based on the defaults.
98+
if max_dets_per_image is None:
99+
max_dets_per_image = [1, 10, 100]
100+
else:
101+
max_dets_per_image = [1, 10, max_dets_per_image]
102+
self._max_dets_per_image = max_dets_per_image
103+
88104
if tasks is not None and isinstance(tasks, CfgNode):
89105
kpt_oks_sigmas = (
90106
tasks.TEST.KEYPOINT_OKS_SIGMAS if not kpt_oks_sigmas else kpt_oks_sigmas
@@ -239,6 +255,7 @@ def _eval_predictions(self, predictions, img_ids=None):
239255
kpt_oks_sigmas=self._kpt_oks_sigmas,
240256
use_fast_impl=self._use_fast_impl,
241257
img_ids=img_ids,
258+
max_dets_per_image=self._max_dets_per_image,
242259
)
243260
if len(coco_results) > 0
244261
else None # cocoapi does not handle empty results very well
@@ -533,7 +550,13 @@ def _evaluate_box_proposals(dataset_predictions, coco_api, thresholds=None, area
533550

534551

535552
def _evaluate_predictions_on_coco(
536-
coco_gt, coco_results, iou_type, kpt_oks_sigmas=None, use_fast_impl=True, img_ids=None
553+
coco_gt,
554+
coco_results,
555+
iou_type,
556+
kpt_oks_sigmas=None,
557+
use_fast_impl=True,
558+
img_ids=None,
559+
max_dets_per_image=None,
537560
):
538561
"""
539562
Evaluate the coco results using COCOEval API.
@@ -551,6 +574,19 @@ def _evaluate_predictions_on_coco(
551574

552575
coco_dt = coco_gt.loadRes(coco_results)
553576
coco_eval = (COCOeval_opt if use_fast_impl else COCOeval)(coco_gt, coco_dt, iou_type)
577+
# For COCO, the default max_dets_per_image is [1, 10, 100].
578+
if max_dets_per_image is None:
579+
max_dets_per_image = [1, 10, 100] # Default from COCOEval
580+
else:
581+
assert (
582+
len(max_dets_per_image) >= 3
583+
), "COCOeval requires maxDets (and max_dets_per_image) to have length at least 3"
584+
# In the case that user supplies a custom input for max_dets_per_image,
585+
# apply COCOevalMaxDets to evaluate AP with the custom input.
586+
if max_dets_per_image[2] != 100:
587+
coco_eval = COCOevalMaxDets(coco_gt, coco_dt, iou_type)
588+
coco_eval.params.maxDets = max_dets_per_image
589+
554590
if img_ids is not None:
555591
coco_eval.params.imgIds = img_ids
556592

@@ -577,3 +613,94 @@ def _evaluate_predictions_on_coco(
577613
coco_eval.summarize()
578614

579615
return coco_eval
616+
617+
618+
class COCOevalMaxDets(COCOeval):
    """
    Modified version of COCOeval for evaluating AP with a custom
    maxDets (by default for COCO, maxDets is 100)

    Only the summary step is overridden: for "segm" and "bbox" evaluation the AP
    metrics are looked up at ``self.params.maxDets[2]`` instead of a fixed 100.
    """

    def summarize(self):
        """
        Compute and display summary metrics for evaluation results given
        a custom value for max_dets_per_image

        One line is printed per metric and the values are stored in ``self.stats``
        (12 entries for "segm"/"bbox", 10 entries for "keypoints").

        Raises:
            Exception: if ``self.eval`` is empty (accumulate() has not been run).
            ValueError: if ``self.params.iouType`` is not one of "segm", "bbox"
                or "keypoints".
        """
        if not self.eval:
            raise Exception("Please run accumulate() first")

        def _summarize(ap=1, iouThr=None, areaRng="all", maxDets=100):
            """
            Print and return one summary number.

            Args:
                ap (int): 1 selects precision (AP); any other value selects recall (AR).
                iouThr (float or None): restrict to this IoU threshold, or None to
                    average over all thresholds in ``params.iouThrs``.
                areaRng (str): label that is matched against ``params.areaRngLbl``.
                maxDets (int): value that is matched against ``params.maxDets``.

            Returns:
                Mean over all selected entries greater than -1, or -1 if no such
                entry exists (including when ``areaRng``/``maxDets`` match nothing).
            """
            p = self.params
            iStr = " {:<18} {} @[ IoU={:<9} | area={:>6s} | maxDets={:>3d} ] = {:0.3f}"
            titleStr = "Average Precision" if ap == 1 else "Average Recall"
            typeStr = "(AP)" if ap == 1 else "(AR)"
            iouStr = (
                "{:0.2f}:{:0.2f}".format(p.iouThrs[0], p.iouThrs[-1])
                if iouThr is None
                else "{:0.2f}".format(iouThr)
            )

            # Positions of the requested area range / detection limit in the params.
            aind = [i for i, aRng in enumerate(p.areaRngLbl) if aRng == areaRng]
            mind = [i for i, mDet in enumerate(p.maxDets) if mDet == maxDets]
            if ap == 1:
                # dimension of precision: [TxRxKxAxM]
                s = self.eval["precision"]
                # IoU
                if iouThr is not None:
                    t = np.where(iouThr == p.iouThrs)[0]
                    s = s[t]
                s = s[:, :, :, aind, mind]
            else:
                # dimension of recall: [TxKxAxM]
                s = self.eval["recall"]
                if iouThr is not None:
                    t = np.where(iouThr == p.iouThrs)[0]
                    s = s[t]
                s = s[:, :, aind, mind]
            # Entries equal to -1 are excluded from the mean.
            valid = s[s > -1]
            mean_s = -1 if len(valid) == 0 else np.mean(valid)
            print(iStr.format(titleStr, typeStr, iouStr, areaRng, maxDets, mean_s))
            return mean_s

        def _summarizeDets():
            """Summarize the 12 box/mask metrics using the limits in params.maxDets."""
            max_dets = self.params.maxDets
            # The third entry is the custom limit on maximum detections per image;
            # it replaces the fixed 100 of the stock COCOeval summary.
            custom_limit = max_dets[2]
            stats = np.zeros((12,))
            # Evaluate AP using the custom limit on maximum detections per image
            stats[0] = _summarize(1, maxDets=custom_limit)
            stats[1] = _summarize(1, iouThr=0.5, maxDets=custom_limit)
            stats[2] = _summarize(1, iouThr=0.75, maxDets=custom_limit)
            stats[3] = _summarize(1, areaRng="small", maxDets=custom_limit)
            stats[4] = _summarize(1, areaRng="medium", maxDets=custom_limit)
            stats[5] = _summarize(1, areaRng="large", maxDets=custom_limit)
            # AR is reported at each of the three limits, then per area at the last.
            stats[6] = _summarize(0, maxDets=max_dets[0])
            stats[7] = _summarize(0, maxDets=max_dets[1])
            stats[8] = _summarize(0, maxDets=custom_limit)
            stats[9] = _summarize(0, areaRng="small", maxDets=custom_limit)
            stats[10] = _summarize(0, areaRng="medium", maxDets=custom_limit)
            stats[11] = _summarize(0, areaRng="large", maxDets=custom_limit)
            return stats

        def _summarizeKps():
            """Summarize the 10 keypoint metrics at the fixed limit of 20 detections."""
            # NOTE(review): the limit here is hard-coded to 20. If params.maxDets
            # does not contain 20, _summarize matches no entry and every keypoint
            # stat comes out as -1 - confirm callers keep 20 in maxDets for keypoints.
            stats = np.zeros((10,))
            stats[0] = _summarize(1, maxDets=20)
            stats[1] = _summarize(1, maxDets=20, iouThr=0.5)
            stats[2] = _summarize(1, maxDets=20, iouThr=0.75)
            stats[3] = _summarize(1, maxDets=20, areaRng="medium")
            stats[4] = _summarize(1, maxDets=20, areaRng="large")
            stats[5] = _summarize(0, maxDets=20)
            stats[6] = _summarize(0, maxDets=20, iouThr=0.5)
            stats[7] = _summarize(0, maxDets=20, iouThr=0.75)
            stats[8] = _summarize(0, maxDets=20, areaRng="medium")
            stats[9] = _summarize(0, maxDets=20, areaRng="large")
            return stats

        iouType = self.params.iouType
        if iouType in ("segm", "bbox"):
            summarize = _summarizeDets
        elif iouType == "keypoints":
            summarize = _summarizeKps
        else:
            # Previously this fell through and failed with an UnboundLocalError.
            raise ValueError(
                f"Unsupported iouType {iouType!r}; expected 'segm', 'bbox' or 'keypoints'"
            )
        self.stats = summarize()

    def __str__(self):
        """
        Return the summary text produced by summarize().

        __str__ must return a string; returning the result of summarize() (None)
        made every str()/print() call on this object raise TypeError. The printed
        summary is captured and returned instead, so print(obj) still shows it.
        """
        import contextlib
        import io

        buffer = io.StringIO()
        with contextlib.redirect_stdout(buffer):
            self.summarize()
        return buffer.getvalue().rstrip("\n")

detectron2/evaluation/lvis_evaluation.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,15 @@ class LVISEvaluator(DatasetEvaluator):
2525
LVIS's metrics and evaluation API.
2626
"""
2727

28-
def __init__(self, dataset_name, tasks=None, distributed=True, output_dir=None):
28+
def __init__(
29+
self,
30+
dataset_name,
31+
tasks=None,
32+
distributed=True,
33+
output_dir=None,
34+
*,
35+
max_dets_per_image=None,
36+
):
2937
"""
3038
Args:
3139
dataset_name (str): name of the dataset to be evaluated.
@@ -37,6 +45,8 @@ def __init__(self, dataset_name, tasks=None, distributed=True, output_dir=None):
3745
distributed (True): if True, will collect results from all ranks for evaluation.
3846
Otherwise, will evaluate the results in the current process.
3947
output_dir (str): optional, an output directory to dump results.
48+
max_dets_per_image (None or int): limit on maximum detections per image in evaluating AP
49+
This limit, by default of the LVIS dataset, is 300.
4050
"""
4151
from lvis import LVIS
4252

@@ -53,6 +63,7 @@ def __init__(self, dataset_name, tasks=None, distributed=True, output_dir=None):
5363

5464
self._distributed = distributed
5565
self._output_dir = output_dir
66+
self._max_dets_per_image = max_dets_per_image
5667

5768
self._cpu_device = torch.device("cpu")
5869

@@ -158,7 +169,11 @@ def _eval_predictions(self, predictions):
158169
self._logger.info("Evaluating predictions ...")
159170
for task in sorted(tasks):
160171
res = _evaluate_predictions_on_lvis(
161-
self._lvis_api, lvis_results, task, class_names=self._metadata.get("thing_classes")
172+
self._lvis_api,
173+
lvis_results,
174+
task,
175+
max_dets_per_image=self._max_dets_per_image,
176+
class_names=self._metadata.get("thing_classes"),
162177
)
163178
self._results[task] = res
164179

@@ -313,11 +328,14 @@ def _evaluate_box_proposals(dataset_predictions, lvis_api, thresholds=None, area
313328
}
314329

315330

316-
def _evaluate_predictions_on_lvis(lvis_gt, lvis_results, iou_type, class_names=None):
331+
def _evaluate_predictions_on_lvis(
332+
lvis_gt, lvis_results, iou_type, max_dets_per_image=None, class_names=None
333+
):
317334
"""
318335
Args:
319336
iou_type (str):
320-
kpt_oks_sigmas (list[float]):
337+
max_dets_per_image (None or int): limit on maximum detections per image in evaluating AP
338+
This limit, by default of the LVIS dataset, is 300.
321339
class_names (None or list[str]): if provided, will use it to predict
322340
per-category AP.
323341
@@ -344,9 +362,13 @@ def _evaluate_predictions_on_lvis(lvis_gt, lvis_results, iou_type, class_names=N
344362
for c in lvis_results:
345363
c.pop("bbox", None)
346364

365+
if max_dets_per_image is None:
366+
max_dets_per_image = 300 # Default for LVIS dataset
367+
347368
from lvis import LVISEval, LVISResults
348369

349-
lvis_results = LVISResults(lvis_gt, lvis_results)
370+
logger.info(f"Evaluating with max detections per image = {max_dets_per_image}")
371+
lvis_results = LVISResults(lvis_gt, lvis_results, max_dets=max_dets_per_image)
350372
lvis_eval = LVISEval(lvis_gt, lvis_results, iou_type)
351373
lvis_eval.run()
352374
lvis_eval.print_results()

0 commit comments

Comments
 (0)