44import valor_lite .classification .numpy_compatibility as npc
55
66
def compute_label_metadata(
    ids: NDArray[np.int32],
    n_labels: int,
) -> NDArray[np.int32]:
    """
    Computes label metadata returning a count of annotations per label.

    A label is counted once per datum it appears on: duplicate
    (datum, label) rows are collapsed before counting. Rows whose label
    index is negative are excluded from the count.

    Parameters
    ----------
    ids : NDArray[np.int32]
        Annotation pairings with shape (n_pairs, 3).
            Index 0 - Datum Index
            Index 1 - GroundTruth Label Index
            Index 2 - Prediction Label Index
    n_labels : int
        The total number of unique labels.

    Returns
    -------
    NDArray[np.int32]
        The label metadata array with shape (n_labels, 2).
            Index 0 - Ground truth label count
            Index 1 - Prediction label count
    """
    label_metadata = np.zeros((n_labels, 2), dtype=np.int32)

    # Column 1 of `ids` (ground truths) fills metadata column 0 and
    # column 2 of `ids` (predictions) fills metadata column 1.
    for metadata_column, label_column in enumerate((1, 2)):
        pairs = ids[:, [0, label_column]]
        pairs = pairs[pairs[:, 1] >= 0]
        if pairs.shape[0] == 0:
            # nothing to count; the column keeps its zero initialisation
            continue

        # collapse repeated (datum, label) rows so each datum counts once
        unique_pairs = np.unique(pairs, axis=0)
        label_indices, unique_counts = np.unique(
            unique_pairs[:, 1], return_counts=True
        )
        label_metadata[
            label_indices.astype(np.int32), metadata_column
        ] = unique_counts

    return label_metadata
49+
50+
def filter_cache(
    detailed_pairs: NDArray[np.float64],
    datum_mask: NDArray[np.bool_],
    valid_label_indices: NDArray[np.int32] | None,
    n_labels: int,
) -> tuple[NDArray[np.float64], NDArray[np.int32]]:
    """
    Filters cached classification pairs by datum and by label.

    Parameters
    ----------
    detailed_pairs : NDArray[np.float64]
        Classification pairs, one per row. Column 0 is the datum index,
        column 1 the ground truth label index, column 2 the prediction
        label index and column 3 the score; column 4 is reset together
        with columns 2 and 3 when a prediction is invalidated.
    datum_mask : NDArray[np.bool_]
        Row mask; only rows where the mask is True are kept.
    valid_label_indices : NDArray[np.int32] | None
        Label indices to keep. When None no label filtering is applied.
    n_labels : int
        The total number of unique labels.

    Returns
    -------
    NDArray[np.float64]
        The filtered, de-duplicated pairs ordered by descending score, then
        by prediction label, then by ground truth label.
    NDArray[np.int32]
        Label metadata computed from the first three columns of the
        filtered pairs.
    """
    # keep only the requested datums, working on a private copy
    pairs = detailed_pairs[datum_mask].copy()

    if valid_label_indices is not None:
        # a ground truth with a filtered-out label loses its label index
        groundtruth_invalid = ~np.isin(pairs[:, 1], valid_label_indices)
        pairs[groundtruth_invalid, 1] = -1.0

        # a prediction with a filtered-out label is blanked in columns 2-4
        prediction_invalid = ~np.isin(pairs[:, 2], valid_label_indices)
        for column in (2, 3, 4):
            pairs[prediction_invalid, column] = -1.0

    # drop rows where both the ground truth and the prediction are blank
    null_row = np.full(4, -1.0)
    is_null = np.isclose(pairs[:, 1:5], null_row).all(axis=1)
    pairs = pairs[~is_null]

    # de-duplicate, then order rows; lexsort treats its last key as primary
    pairs = np.unique(pairs, axis=0)
    groundtruth_key = pairs[:, 1]
    prediction_key = pairs[:, 2]
    descending_score_key = -pairs[:, 3]
    order = np.lexsort(
        (groundtruth_key, prediction_key, descending_score_key)
    )
    pairs = pairs[order]

    label_metadata = compute_label_metadata(
        ids=pairs[:, :3].astype(np.int32),
        n_labels=n_labels,
    )
    return pairs, label_metadata
110+
111+
7112def _compute_rocauc (
8113 data : NDArray [np .float64 ],
9114 label_metadata : NDArray [np .int32 ],
@@ -67,7 +172,7 @@ def _compute_rocauc(
67172
68173
69174def compute_precision_recall_rocauc (
70- data : NDArray [np .float64 ],
175+ detailed_pairs : NDArray [np .float64 ],
71176 label_metadata : NDArray [np .int32 ],
72177 score_thresholds : NDArray [np .float64 ],
73178 hardmax : bool ,
@@ -84,20 +189,19 @@ def compute_precision_recall_rocauc(
84189 """
85190 Computes classification metrics.
86191
87- Takes data with shape (N, 5):
88-
89- Index 0 - Datum Index
90- Index 1 - GroundTruth Label Index
91- Index 2 - Prediction Label Index
92- Index 3 - Score
93- Index 4 - Hard-Max Score
94-
95192 Parameters
96193 ----------
97- data : NDArray[np.float64]
98- A sorted array of classification pairs.
194+ detailed_pairs : NDArray[np.float64]
195+ A sorted array of classification pairs with shape (n_pairs, 5).
196+ Index 0 - Datum Index
197+ Index 1 - GroundTruth Label Index
198+ Index 2 - Prediction Label Index
199+ Index 3 - Score
200+ Index 4 - Hard-Max Score
99201 label_metadata : NDArray[np.int32]
100- An array containing metadata related to labels.
202+ An array containing metadata related to labels with shape (n_labels, 2).
203+ Index 0 - GroundTruth Label Count
204+ Index 1 - Prediction Label Count
101205 score_thresholds : NDArray[np.float64]
102206 A 1-D array contains score thresholds to compute metrics over.
103207 hardmax : bool
@@ -126,15 +230,17 @@ def compute_precision_recall_rocauc(
126230 n_labels = label_metadata .shape [0 ]
127231 n_scores = score_thresholds .shape [0 ]
128232
129- pd_labels = data [:, 2 ].astype (int )
233+ pd_labels = detailed_pairs [:, 2 ].astype (int )
130234
131- mask_matching_labels = np .isclose (data [:, 1 ], data [:, 2 ])
132- mask_score_nonzero = ~ np .isclose (data [:, 3 ], 0.0 )
133- mask_hardmax = data [:, 4 ] > 0.5
235+ mask_matching_labels = np .isclose (
236+ detailed_pairs [:, 1 ], detailed_pairs [:, 2 ]
237+ )
238+ mask_score_nonzero = ~ np .isclose (detailed_pairs [:, 3 ], 0.0 )
239+ mask_hardmax = detailed_pairs [:, 4 ] > 0.5
134240
135241 # calculate ROCAUC
136242 rocauc , mean_rocauc = _compute_rocauc (
137- data = data ,
243+ data = detailed_pairs ,
138244 label_metadata = label_metadata ,
139245 n_datums = n_datums ,
140246 n_labels = n_labels ,
@@ -145,7 +251,9 @@ def compute_precision_recall_rocauc(
145251 # calculate metrics at various score thresholds
146252 counts = np .zeros ((n_scores , n_labels , 4 ), dtype = np .int32 )
147253 for score_idx in range (n_scores ):
148- mask_score_threshold = data [:, 3 ] >= score_thresholds [score_idx ]
254+ mask_score_threshold = (
255+ detailed_pairs [:, 3 ] >= score_thresholds [score_idx ]
256+ )
149257 mask_score = mask_score_nonzero & mask_score_threshold
150258
151259 if hardmax :
@@ -156,8 +264,8 @@ def compute_precision_recall_rocauc(
156264 mask_fn = (mask_matching_labels & ~ mask_score ) | mask_fp
157265 mask_tn = ~ mask_matching_labels & ~ mask_score
158266
159- fn = np .unique (data [mask_fn ][:, [0 , 1 ]].astype (int ), axis = 0 )
160- tn = np .unique (data [mask_tn ][:, [0 , 2 ]].astype (int ), axis = 0 )
267+ fn = np .unique (detailed_pairs [mask_fn ][:, [0 , 1 ]].astype (int ), axis = 0 )
268+ tn = np .unique (detailed_pairs [mask_tn ][:, [0 , 2 ]].astype (int ), axis = 0 )
161269
162270 counts [score_idx , :, 0 ] = np .bincount (
163271 pd_labels [mask_tp ], minlength = n_labels
@@ -249,7 +357,7 @@ def _count_with_examples(
249357
250358
251359def compute_confusion_matrix (
252- data : NDArray [np .float64 ],
360+ detailed_pairs : NDArray [np .float64 ],
253361 label_metadata : NDArray [np .int32 ],
254362 score_thresholds : NDArray [np .float64 ],
255363 hardmax : bool ,
@@ -260,18 +368,19 @@ def compute_confusion_matrix(
260368
261369 Takes data with shape (N, 5):
262370
263- Index 0 - Datum Index
264- Index 1 - GroundTruth Label Index
265- Index 2 - Prediction Label Index
266- Index 3 - Score
267- Index 4 - Hard Max Score
268-
269371 Parameters
270372 ----------
271- data : NDArray[np.float64]
272- A sorted array summarizing the IOU calculations of one or more pairs.
373+ detailed_pairs : NDArray[np.float64]
374+ A 2-D sorted array summarizing the IOU calculations of one or more pairs with shape (n_pairs, 5).
375+ Index 0 - Datum Index
376+ Index 1 - GroundTruth Label Index
377+ Index 2 - Prediction Label Index
378+ Index 3 - Score
379+ Index 4 - Hard Max Score
273380 label_metadata : NDArray[np.int32]
274- An array containing metadata related to labels.
381+ A 2-D array containing metadata related to labels with shape (n_labels, 2).
382+ Index 0 - GroundTruth Label Count
383+ Index 1 - Prediction Label Count
275384 iou_thresholds : NDArray[np.float64]
276385 A 1-D array containing IOU thresholds.
277386 score_thresholds : NDArray[np.float64]
@@ -301,15 +410,15 @@ def compute_confusion_matrix(
301410 dtype = np .int32 ,
302411 )
303412
304- mask_label_match = np .isclose (data [:, 1 ], data [:, 2 ])
305- mask_score = data [:, 3 ] > 1e-9
413+ mask_label_match = np .isclose (detailed_pairs [:, 1 ], detailed_pairs [:, 2 ])
414+ mask_score = detailed_pairs [:, 3 ] > 1e-9
306415
307- groundtruths = data [:, [0 , 1 ]].astype (int )
416+ groundtruths = detailed_pairs [:, [0 , 1 ]].astype (int )
308417
309418 for score_idx in range (n_scores ):
310- mask_score &= data [:, 3 ] >= score_thresholds [score_idx ]
419+ mask_score &= detailed_pairs [:, 3 ] >= score_thresholds [score_idx ]
311420 if hardmax :
312- mask_score &= data [:, 4 ] > 0.5
421+ mask_score &= detailed_pairs [:, 4 ] > 0.5
313422
314423 mask_tp = mask_label_match & mask_score
315424 mask_misclf = ~ mask_label_match & mask_score
@@ -323,17 +432,17 @@ def compute_confusion_matrix(
323432 )
324433
325434 tp_examples , tp_labels , tp_counts = _count_with_examples (
326- data = data [mask_tp ],
435+ data = detailed_pairs [mask_tp ],
327436 unique_idx = [0 , 2 ],
328437 label_idx = 1 ,
329438 )
330439 misclf_examples , misclf_labels , misclf_counts = _count_with_examples (
331- data = data [mask_misclf ],
440+ data = detailed_pairs [mask_misclf ],
332441 unique_idx = [0 , 1 , 2 ],
333442 label_idx = [1 , 2 ],
334443 )
335444 misprd_examples , misprd_labels , misprd_counts = _count_with_examples (
336- data = data [mask_misprd ],
445+ data = detailed_pairs [mask_misprd ],
337446 unique_idx = [0 , 1 ],
338447 label_idx = 1 ,
339448 )
0 commit comments