@@ -176,11 +176,32 @@ def compute_polygon_iou(
176176 return ious
177177
178178
179- def rank_pairs (sorted_pairs : NDArray [np .float64 ]):
179+ def rank_pairs (
180+ sorted_pairs : NDArray [np .float64 ],
181+ ) -> tuple [NDArray [np .float64 ], NDArray [np .intp ]]:
180182 """
181183 Prunes and ranks prediction pairs.
182184
183185 Should result in a single pair per prediction annotation.
186+
187+ Parameters
188+ ----------
189+ sorted_pairs : NDArray[np.float64]
190+ Ranked annotation pairs.
191+ Index 0 - Datum Index
192+ Index 1 - GroundTruth Index
193+ Index 2 - Prediction Index
194+ Index 3 - GroundTruth Label Index
195+ Index 4 - Prediction Label Index
196+ Index 5 - IOU
197+ Index 6 - Score
198+
199+ Returns
200+ -------
201+ NDArray[float64]
202+ Ranked prediction pairs.
203+ NDArray[intp]
204+ Indices of ranked prediction pairs.
184205 """
185206
186207 # remove unmatched ground truths
@@ -197,8 +218,10 @@ def rank_pairs(sorted_pairs: NDArray[np.float64]):
197218 pairs = pairs [mask_label_match | mask_unmatched_predictions ]
198219 indices = indices [mask_label_match | mask_unmatched_predictions ]
199220
200- # only keep the highest ranked pair
201- _ , unique_indices = np .unique (pairs [:, [0 , 2 ]], axis = 0 , return_index = True )
221+ # only keep the highest ranked prediction (datum_id, prediction_id, predicted_label_id)
222+ _ , unique_indices = np .unique (
223+ pairs [:, [0 , 2 , 4 ]], axis = 0 , return_index = True
224+ )
202225 pairs = pairs [unique_indices ]
203226 indices = indices [unique_indices ]
204227
@@ -216,55 +239,57 @@ def rank_pairs(sorted_pairs: NDArray[np.float64]):
216239
217240
218241def calculate_ranking_boundaries (
219- ranked_pairs : NDArray [np .float64 ], number_of_labels : int
220- ):
221- dt_gt_ids = ranked_pairs [:, (0 , 1 )].astype (np .int64 )
222- gt_ids = dt_gt_ids [:, 1 ]
223- ious = ranked_pairs [:, 5 ]
242+ ranked_pairs : NDArray [np .float64 ],
243+ ) -> NDArray [np .float64 ]:
244+ """
245+ Determine IOU boundaries for computing AP across chunks.
224246
225- unique_gts , gt_counts = np .unique (
226- dt_gt_ids ,
227- return_counts = True ,
228- axis = 0 ,
229- )
230- unique_gts = unique_gts [gt_counts > 1 ] # select gts with many pairs
231- unique_gts = unique_gts [unique_gts [:, 1 ] >= 0 ] # remove null
247+ Parameters
248+ ----------
249+ ranked_pairs : NDArray[np.float64]
250+ Ranked annotation pairs.
251+ Index 0 - Datum Index
252+ Index 1 - GroundTruth Index
253+ Index 2 - Prediction Index
254+ Index 3 - GroundTruth Label Index
255+ Index 4 - Prediction Label Index
256+ Index 5 - IOU
257+ Index 6 - Score
258+
259+ Returns
260+ -------
261+ NDArray[np.float64]
262+ A 1-D array containing the lower IOU boundary for classifying pairs as true-positive across chunks.
263+ """
264+ # groundtruths defined as (datum_id, groundtruth_id, groundtruth_label_id)
265+ gts = ranked_pairs [:, (0 , 1 , 3 )].astype (np .int64 )
266+ ious = ranked_pairs [:, 5 ]
232267
233- winning_predictions = np .ones_like (ious , dtype = np .bool_ )
234- winning_predictions [gt_ids < 0 ] = False # null gts cannot be won
235- iou_boundary = np .zeros_like (ious )
268+ iou_boundary = np .ones_like (ious ) * 2 # impossible bound
236269
270+ mask_valid_gts = gts [:, 1 ] >= 0
271+ unique_gts = np .unique (gts [mask_valid_gts ], axis = 0 )
237272 for gt in unique_gts :
238- mask_gts = (
239- ranked_pairs [:, (0 , 1 )].astype (np .int64 ) == (gt [0 ], gt [1 ])
240- ).all (axis = 1 )
241- for label in range (number_of_labels ):
242- mask_plabel = (ranked_pairs [:, 4 ] == label ) & mask_gts
243- if mask_plabel .sum () <= 1 :
244- continue
273+ mask_gt = (gts == gt ).all (axis = 1 )
274+ if mask_gt .sum () <= 1 :
275+ iou_boundary [mask_gt ] = 0.0
276+ continue
245277
246- # mark sequence of increasing IOUs starting from index 0
247- labeled_ious = ranked_pairs [mask_plabel , 5 ]
248- mask_increasing_iou = np .ones_like (labeled_ious , dtype = np .bool_ )
249- mask_increasing_iou [1 :] = labeled_ious [1 :] > labeled_ious [:- 1 ]
250- idx_dec = np .where (~ mask_increasing_iou )[0 ]
251- if idx_dec .size == 1 :
252- mask_increasing_iou [idx_dec [0 ] :] = False
278+ running_max = np .maximum .accumulate (ious [mask_gt ])
279+ mask_rmax = np .isclose (running_max , ious [mask_gt ])
280+ mask_rmax [1 :] &= running_max [1 :] > running_max [:- 1 ]
281+ mask_gt [mask_gt ] &= mask_rmax
253282
254- # define IOU lower bound
255- iou_boundary [mask_plabel ][1 :] = labeled_ious [:- 1 ]
256- iou_boundary [mask_plabel ][
257- ~ mask_increasing_iou
258- ] = 2.0 # arbitrary >1.0 value
283+ indices = np .where (mask_gt )[0 ]
259284
260- # mark first element (highest score)
261- indices = np .where (mask_gts )[0 ][1 :]
262- winning_predictions [indices ] = False
285+ iou_boundary [indices [0 ]] = 0.0
286+ iou_boundary [indices [1 :]] = ious [indices [:- 1 ]]
263287
264- return iou_boundary , winning_predictions
288+ return iou_boundary
265289
266290
267- def rank_table (tbl : pa .Table , number_of_labels : int ) -> pa .Table :
291+ def rank_table (tbl : pa .Table ) -> pa .Table :
292+ """Rank table for AP computation."""
268293 numeric_columns = [
269294 "datum_id" ,
270295 "gt_id" ,
@@ -278,24 +303,24 @@ def rank_table(tbl: pa.Table, number_of_labels: int) -> pa.Table:
278303 ("pd_score" , "descending" ),
279304 ("iou" , "descending" ),
280305 ]
306+
307+ # initial sort
281308 sorted_tbl = tbl .sort_by (sorting_args )
282309 pairs = np .column_stack (
283310 [sorted_tbl [col ].to_numpy () for col in numeric_columns ]
284311 )
285- pairs , indices = rank_pairs (pairs )
312+
313+ # rank pairs
314+ ranked_pairs , indices = rank_pairs (pairs )
286315 ranked_tbl = sorted_tbl .take (indices )
287- lower_iou_bound , winning_predictions = calculate_ranking_boundaries (
288- pairs , number_of_labels = number_of_labels
289- )
290- ranked_tbl = ranked_tbl .append_column (
291- pa .field ("high_score" , pa .bool_ ()),
292- pa .array (winning_predictions , type = pa .bool_ ()),
293- )
316+
317+ # find boundaries
318+ lower_iou_bound = calculate_ranking_boundaries (ranked_pairs )
294319 ranked_tbl = ranked_tbl .append_column (
295320 pa .field ("iou_prev" , pa .float64 ()),
296321 pa .array (lower_iou_bound , type = pa .float64 ()),
297322 )
298- ranked_tbl = ranked_tbl . sort_by ( sorting_args )
323+
299324 return ranked_tbl
300325
301326
@@ -306,41 +331,42 @@ def compute_counts(
306331 number_of_groundtruths_per_label : NDArray [np .uint64 ],
307332 number_of_labels : int ,
308333 running_counts : NDArray [np .uint64 ],
309- ) -> tuple :
334+ pr_curve : NDArray [np .float64 ],
335+ ) -> NDArray [np .uint64 ]:
310336 """
311337 Computes Object Detection metrics.
312338
313- Takes data with shape (N, 7):
314-
315- Index 0 - Datum Index
316- Index 1 - GroundTruth Index
317- Index 2 - Prediction Index
318- Index 3 - GroundTruth Label Index
319- Index 4 - Prediction Label Index
320- Index 5 - IOU
321- Index 6 - Score
322- Index 7 - IOU Lower Boundary
323- Index 8 - Winning Prediction
339+ Precision-recall curve and running counts are updated in-place.
324340
325341 Parameters
326342 ----------
327343 ranked_pairs : NDArray[np.float64]
328344 A ranked array summarizing the IOU calculations of one or more pairs.
345+ Index 0 - Datum Index
346+ Index 1 - GroundTruth Index
347+ Index 2 - Prediction Index
348+ Index 3 - GroundTruth Label Index
349+ Index 4 - Prediction Label Index
350+ Index 5 - IOU
351+ Index 6 - Score
352+ Index 7 - IOU Lower Boundary
329353 iou_thresholds : NDArray[np.float64]
330354 A 1-D array containing IOU thresholds.
331355 score_thresholds : NDArray[np.float64]
332356 A 1-D array containing score thresholds.
357+ number_of_groundtruths_per_label : NDArray[np.uint64]
358+ A 1-D array containing total number of ground truths per label.
359+ number_of_labels : int
360+ Total number of unique labels.
361+ running_counts : NDArray[np.uint64]
362+ A 2-D array containing running counts of total predictions and true-positive. This array is mutated.
363+ pr_curve : NDArray[np.float64]
364+ A 2-D array containing 101-point binning of precision and score over a fixed recall interval. This array is mutated.
333365
334366 Returns
335367 -------
336- tuple[NDArray[np.float64], NDArray[np.float64]]
337- Average Precision results (AP, mAP).
338- tuple[NDArray[np.float64], NDArray[np.float64]]
339- Average Recall results (AR, mAR).
340- NDArray[np.float64]
341- Precision, Recall, TP, FP, FN, F1 Score.
342- NDArray[np.float64]
343- Interpolated Precision-Recall Curves.
368+ NDArray[uint64]
369+ Batched counts of TP, FP, FN.
344370 """
345371 n_rows = ranked_pairs .shape [0 ]
346372 n_labels = number_of_labels
@@ -349,7 +375,6 @@ def compute_counts(
349375
350376 # initialize result arrays
351377 counts = np .zeros ((n_ious , n_scores , 3 , n_labels ), dtype = np .uint64 )
352- pr_curve = np .zeros ((n_ious , n_labels , 101 , 2 ))
353378
354379 # start computation
355380 ids = ranked_pairs [:, :5 ].astype (np .int64 )
@@ -359,7 +384,6 @@ def compute_counts(
359384 ious = ranked_pairs [:, 5 ]
360385 scores = ranked_pairs [:, 6 ]
361386 prev_ious = ranked_pairs [:, 7 ]
362- winners = ranked_pairs [:, 8 ].astype (np .bool_ )
363387
364388 unique_pd_labels , _ = np .unique (pd_labels , return_index = True )
365389
@@ -384,9 +408,9 @@ def compute_counts(
384408 mask_iou_prev = prev_ious < iou_thresholds [iou_idx ]
385409 mask_iou = mask_iou_curr & mask_iou_prev
386410
387- mask_tp_outer = mask_tp & mask_iou & winners
411+ mask_tp_outer = mask_tp & mask_iou
388412 mask_fp_outer = mask_fp & (
389- (~ mask_gt_exists_labels_match & mask_iou ) | ~ mask_iou | ~ winners
413+ (~ mask_gt_exists_labels_match & mask_iou ) | ~ mask_iou
390414 )
391415
392416 for score_idx in range (n_scores ):
@@ -421,33 +445,29 @@ def compute_counts(
421445 )
422446
423447 # create true-positive mask score threshold
424- tp_candidates = ids [mask_tp_outer ]
425- _ , indices_gt_unique = np .unique (
426- tp_candidates [:, [0 , 1 , 3 ]], axis = 0 , return_index = True
427- )
428- mask_gt_unique = np .zeros (tp_candidates .shape [0 ], dtype = np .bool_ )
429- mask_gt_unique [indices_gt_unique ] = True
430- true_positives_mask = np .zeros (n_rows , dtype = np .bool_ )
431- true_positives_mask [mask_tp_outer ] = mask_gt_unique
448+ mask_tps = mask_tp_outer
449+ true_positives_mask = mask_tps & mask_iou_prev
432450
433451 # count running tp and total for AP
434452 for pd_label in unique_pd_labels :
435453 mask_pd_label = pd_labels == pd_label
454+ total_count = mask_pd_label .sum ()
455+ if total_count == 0 :
456+ continue
436457
437458 # running total prediction count
438- total_count = mask_pd_label .sum ()
439- running_total_count [iou_idx ][mask_pd_label ] = np .arange (
440- running_counts [iou_idx , pd_label , 0 ],
441- running_counts [iou_idx , pd_label , 0 ] + total_count ,
459+ running_total_count [iou_idx , mask_pd_label ] = np .arange (
460+ running_counts [iou_idx , pd_label , 0 ] + 1 ,
461+ running_counts [iou_idx , pd_label , 0 ] + total_count + 1 ,
442462 )
443463 running_counts [iou_idx , pd_label , 0 ] += total_count
444464
445465 # running true-positive count
446466 mask_tp_for_counting = mask_pd_label & true_positives_mask
447467 tp_count = mask_tp_for_counting .sum ()
448- running_tp_count [iou_idx ][ mask_tp_for_counting ] = np .arange (
449- running_counts [iou_idx , pd_label , 1 ],
450- running_counts [iou_idx , pd_label , 1 ] + tp_count ,
468+ running_tp_count [iou_idx , mask_tp_for_counting ] = np .arange (
469+ running_counts [iou_idx , pd_label , 1 ] + 1 ,
470+ running_counts [iou_idx , pd_label , 1 ] + tp_count + 1 ,
451471 )
452472 running_counts [iou_idx , pd_label , 1 ] += tp_count
453473
@@ -474,15 +494,14 @@ def compute_counts(
474494 pr_curve [iou_idx , pd_labels , recall_index [iou_idx ], 0 ],
475495 precision [iou_idx ],
476496 )
477- pr_curve [iou_idx , pd_labels , recall_index [iou_idx ], 1 ] = np .maximum (
478- pr_curve [iou_idx , pd_labels , recall_index [iou_idx ], 1 ],
479- scores ,
497+ pr_curve [
498+ iou_idx , pd_labels [::- 1 ], recall_index [iou_idx ][::- 1 ], 1
499+ ] = np .maximum (
500+ pr_curve [iou_idx , pd_labels [::- 1 ], recall_index [iou_idx ][::- 1 ], 1 ],
501+ scores [::- 1 ],
480502 )
481503
482- return (
483- counts ,
484- pr_curve ,
485- )
504+ return counts
486505
487506
488507def compute_precision_recall_f1 (
0 commit comments