Skip to content

Commit f581aa5

Browse files
authored
Update Classification to match Object Detection v0.35.0 (#840)
1 parent 5518b2d commit f581aa5

File tree

16 files changed

+883
-630
lines changed

16 files changed

+883
-630
lines changed

benchmarks/benchmark_classification.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ def run_benchmarking_analysis(
215215
eval_time, _ = time_it(evaluator.compute_precision_recall_rocauc)()
216216
if eval_time > evaluation_timeout and evaluation_timeout != -1:
217217
raise TimeoutError(
218-
f"Base evaluation timed out with {evaluator.n_datums} datums."
218+
f"Base evaluation timed out with {evaluator.metadata.number_of_datums} datums."
219219
)
220220

221221
detail_no_examples_time, _ = time_it(
@@ -228,7 +228,7 @@ def run_benchmarking_analysis(
228228
and evaluation_timeout != -1
229229
):
230230
raise TimeoutError(
231-
f"Base evaluation timed out with {evaluator.n_datums} datums."
231+
f"Base evaluation timed out with {evaluator.metadata.number_of_datums} datums."
232232
)
233233

234234
detail_three_examples_time, _ = time_it(
@@ -241,16 +241,16 @@ def run_benchmarking_analysis(
241241
and evaluation_timeout != -1
242242
):
243243
raise TimeoutError(
244-
f"Base evaluation timed out with {evaluator.n_datums} datums."
244+
f"Base evaluation timed out with {evaluator.metadata.number_of_datums} datums."
245245
)
246246

247247
results.append(
248248
Benchmark(
249249
limit=limit,
250-
n_datums=evaluator.n_datums,
251-
n_groundtruths=evaluator.n_groundtruths,
252-
n_predictions=evaluator.n_predictions,
253-
n_labels=evaluator.n_labels,
250+
n_datums=evaluator.metadata.number_of_datums,
251+
n_groundtruths=evaluator.metadata.number_of_ground_truths,
252+
n_predictions=evaluator.metadata.number_of_predictions,
253+
n_labels=evaluator.metadata.number_of_labels,
254254
chunk_size=chunk_size,
255255
ingestion=ingest_time,
256256
preprocessing=preprocessing_time,

src/valor_lite/classification/computation.py

Lines changed: 147 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,111 @@
44
import valor_lite.classification.numpy_compatibility as npc
55

66

7+
def compute_label_metadata(
8+
ids: NDArray[np.int32],
9+
n_labels: int,
10+
) -> NDArray[np.int32]:
11+
"""
12+
Computes label metadata returning a count of annotations per label.
13+
14+
Parameters
15+
----------
16+
ids : NDArray[np.int32]
17+
Detailed annotation pairings with shape (n_pairs, 3).
18+
Index 0 - Datum Index
19+
Index 1 - GroundTruth Label Index
20+
Index 2 - Prediction Label Index
21+
n_labels : int
22+
The total number of unique labels.
23+
24+
Returns
25+
-------
26+
NDArray[np.int32]
27+
The label metadata array with shape (n_labels, 2).
28+
Index 0 - Ground truth label count
29+
Index 1 - Prediction label count
30+
"""
31+
label_metadata = np.zeros((n_labels, 2), dtype=np.int32)
32+
ground_truth_pairs = ids[:, (0, 1)]
33+
ground_truth_pairs = ground_truth_pairs[ground_truth_pairs[:, 1] >= 0]
34+
unique_pairs = np.unique(ground_truth_pairs, axis=0)
35+
label_indices, unique_counts = np.unique(
36+
unique_pairs[:, 1], return_counts=True
37+
)
38+
label_metadata[label_indices.astype(np.int32), 0] = unique_counts
39+
40+
prediction_pairs = ids[:, (0, 2)]
41+
prediction_pairs = prediction_pairs[prediction_pairs[:, 1] >= 0]
42+
unique_pairs = np.unique(prediction_pairs, axis=0)
43+
label_indices, unique_counts = np.unique(
44+
unique_pairs[:, 1], return_counts=True
45+
)
46+
label_metadata[label_indices.astype(np.int32), 1] = unique_counts
47+
48+
return label_metadata
49+
50+
51+
def filter_cache(
52+
detailed_pairs: NDArray[np.float64],
53+
datum_mask: NDArray[np.bool_],
54+
valid_label_indices: NDArray[np.int32] | None,
55+
n_labels: int,
56+
) -> tuple[NDArray[np.float64], NDArray[np.int32]]:
57+
# filter by datum
58+
detailed_pairs = detailed_pairs[datum_mask].copy()
59+
60+
n_rows = detailed_pairs.shape[0]
61+
mask_invalid_groundtruths = np.zeros(n_rows, dtype=np.bool_)
62+
mask_invalid_predictions = np.zeros_like(mask_invalid_groundtruths)
63+
64+
# filter labels
65+
if valid_label_indices is not None:
66+
mask_invalid_groundtruths[
67+
~np.isin(detailed_pairs[:, 1], valid_label_indices)
68+
] = True
69+
mask_invalid_predictions[
70+
~np.isin(detailed_pairs[:, 2], valid_label_indices)
71+
] = True
72+
73+
# filter cache
74+
if mask_invalid_groundtruths.any():
75+
invalid_groundtruth_indices = np.where(mask_invalid_groundtruths)[0]
76+
detailed_pairs[invalid_groundtruth_indices[:, None], 1] = np.array(
77+
[[-1.0]]
78+
)
79+
80+
if mask_invalid_predictions.any():
81+
invalid_prediction_indices = np.where(mask_invalid_predictions)[0]
82+
detailed_pairs[
83+
invalid_prediction_indices[:, None], (2, 3, 4)
84+
] = np.array([[-1.0, -1.0, -1.0]])
85+
86+
# filter null pairs
87+
mask_null_pairs = np.all(
88+
np.isclose(
89+
detailed_pairs[:, 1:5],
90+
np.array([-1.0, -1.0, -1.0, -1.0]),
91+
),
92+
axis=1,
93+
)
94+
detailed_pairs = detailed_pairs[~mask_null_pairs]
95+
96+
detailed_pairs = np.unique(detailed_pairs, axis=0)
97+
indices = np.lexsort(
98+
(
99+
detailed_pairs[:, 1], # ground truth
100+
detailed_pairs[:, 2], # prediction
101+
-detailed_pairs[:, 3], # score
102+
)
103+
)
104+
detailed_pairs = detailed_pairs[indices]
105+
label_metadata = compute_label_metadata(
106+
ids=detailed_pairs[:, :3].astype(np.int32),
107+
n_labels=n_labels,
108+
)
109+
return detailed_pairs, label_metadata
110+
111+
7112
def _compute_rocauc(
8113
data: NDArray[np.float64],
9114
label_metadata: NDArray[np.int32],
@@ -67,7 +172,7 @@ def _compute_rocauc(
67172

68173

69174
def compute_precision_recall_rocauc(
70-
data: NDArray[np.float64],
175+
detailed_pairs: NDArray[np.float64],
71176
label_metadata: NDArray[np.int32],
72177
score_thresholds: NDArray[np.float64],
73178
hardmax: bool,
@@ -84,20 +189,19 @@ def compute_precision_recall_rocauc(
84189
"""
85190
Computes classification metrics.
86191
87-
Takes data with shape (N, 5):
88-
89-
Index 0 - Datum Index
90-
Index 1 - GroundTruth Label Index
91-
Index 2 - Prediction Label Index
92-
Index 3 - Score
93-
Index 4 - Hard-Max Score
94-
95192
Parameters
96193
----------
97-
data : NDArray[np.float64]
98-
A sorted array of classification pairs.
194+
detailed_pairs : NDArray[np.float64]
195+
A sorted array of classification pairs with shape (n_pairs, 5).
196+
Index 0 - Datum Index
197+
Index 1 - GroundTruth Label Index
198+
Index 2 - Prediction Label Index
199+
Index 3 - Score
200+
Index 4 - Hard-Max Score
99201
label_metadata : NDArray[np.int32]
100-
An array containing metadata related to labels.
202+
An array containing metadata related to labels with shape (n_labels, 2).
203+
Index 0 - GroundTruth Label Count
204+
Index 1 - Prediction Label Count
101205
score_thresholds : NDArray[np.float64]
102206
A 1-D array contains score thresholds to compute metrics over.
103207
hardmax : bool
@@ -126,15 +230,17 @@ def compute_precision_recall_rocauc(
126230
n_labels = label_metadata.shape[0]
127231
n_scores = score_thresholds.shape[0]
128232

129-
pd_labels = data[:, 2].astype(int)
233+
pd_labels = detailed_pairs[:, 2].astype(int)
130234

131-
mask_matching_labels = np.isclose(data[:, 1], data[:, 2])
132-
mask_score_nonzero = ~np.isclose(data[:, 3], 0.0)
133-
mask_hardmax = data[:, 4] > 0.5
235+
mask_matching_labels = np.isclose(
236+
detailed_pairs[:, 1], detailed_pairs[:, 2]
237+
)
238+
mask_score_nonzero = ~np.isclose(detailed_pairs[:, 3], 0.0)
239+
mask_hardmax = detailed_pairs[:, 4] > 0.5
134240

135241
# calculate ROCAUC
136242
rocauc, mean_rocauc = _compute_rocauc(
137-
data=data,
243+
data=detailed_pairs,
138244
label_metadata=label_metadata,
139245
n_datums=n_datums,
140246
n_labels=n_labels,
@@ -145,7 +251,9 @@ def compute_precision_recall_rocauc(
145251
# calculate metrics at various score thresholds
146252
counts = np.zeros((n_scores, n_labels, 4), dtype=np.int32)
147253
for score_idx in range(n_scores):
148-
mask_score_threshold = data[:, 3] >= score_thresholds[score_idx]
254+
mask_score_threshold = (
255+
detailed_pairs[:, 3] >= score_thresholds[score_idx]
256+
)
149257
mask_score = mask_score_nonzero & mask_score_threshold
150258

151259
if hardmax:
@@ -156,8 +264,8 @@ def compute_precision_recall_rocauc(
156264
mask_fn = (mask_matching_labels & ~mask_score) | mask_fp
157265
mask_tn = ~mask_matching_labels & ~mask_score
158266

159-
fn = np.unique(data[mask_fn][:, [0, 1]].astype(int), axis=0)
160-
tn = np.unique(data[mask_tn][:, [0, 2]].astype(int), axis=0)
267+
fn = np.unique(detailed_pairs[mask_fn][:, [0, 1]].astype(int), axis=0)
268+
tn = np.unique(detailed_pairs[mask_tn][:, [0, 2]].astype(int), axis=0)
161269

162270
counts[score_idx, :, 0] = np.bincount(
163271
pd_labels[mask_tp], minlength=n_labels
@@ -249,7 +357,7 @@ def _count_with_examples(
249357

250358

251359
def compute_confusion_matrix(
252-
data: NDArray[np.float64],
360+
detailed_pairs: NDArray[np.float64],
253361
label_metadata: NDArray[np.int32],
254362
score_thresholds: NDArray[np.float64],
255363
hardmax: bool,
@@ -260,18 +368,19 @@ def compute_confusion_matrix(
260368
261369
Takes data with shape (N, 5):
262370
263-
Index 0 - Datum Index
264-
Index 1 - GroundTruth Label Index
265-
Index 2 - Prediction Label Index
266-
Index 3 - Score
267-
Index 4 - Hard Max Score
268-
269371
Parameters
270372
----------
271-
data : NDArray[np.float64]
272-
A sorted array summarizing the IOU calculations of one or more pairs.
373+
detailed_pairs : NDArray[np.float64]
374+
A 2-D sorted array of classification pairs with shape (n_pairs, 5).
375+
Index 0 - Datum Index
376+
Index 1 - GroundTruth Label Index
377+
Index 2 - Prediction Label Index
378+
Index 3 - Score
379+
Index 4 - Hard Max Score
273380
label_metadata : NDArray[np.int32]
274-
An array containing metadata related to labels.
381+
A 2-D array containing metadata related to labels with shape (n_labels, 2).
382+
Index 0 - GroundTruth Label Count
383+
Index 1 - Prediction Label Count
275384
iou_thresholds : NDArray[np.float64]
276385
A 1-D array containing IOU thresholds.
277386
score_thresholds : NDArray[np.float64]
@@ -301,15 +410,15 @@ def compute_confusion_matrix(
301410
dtype=np.int32,
302411
)
303412

304-
mask_label_match = np.isclose(data[:, 1], data[:, 2])
305-
mask_score = data[:, 3] > 1e-9
413+
mask_label_match = np.isclose(detailed_pairs[:, 1], detailed_pairs[:, 2])
414+
mask_score = detailed_pairs[:, 3] > 1e-9
306415

307-
groundtruths = data[:, [0, 1]].astype(int)
416+
groundtruths = detailed_pairs[:, [0, 1]].astype(int)
308417

309418
for score_idx in range(n_scores):
310-
mask_score &= data[:, 3] >= score_thresholds[score_idx]
419+
mask_score &= detailed_pairs[:, 3] >= score_thresholds[score_idx]
311420
if hardmax:
312-
mask_score &= data[:, 4] > 0.5
421+
mask_score &= detailed_pairs[:, 4] > 0.5
313422

314423
mask_tp = mask_label_match & mask_score
315424
mask_misclf = ~mask_label_match & mask_score
@@ -323,17 +432,17 @@ def compute_confusion_matrix(
323432
)
324433

325434
tp_examples, tp_labels, tp_counts = _count_with_examples(
326-
data=data[mask_tp],
435+
data=detailed_pairs[mask_tp],
327436
unique_idx=[0, 2],
328437
label_idx=1,
329438
)
330439
misclf_examples, misclf_labels, misclf_counts = _count_with_examples(
331-
data=data[mask_misclf],
440+
data=detailed_pairs[mask_misclf],
332441
unique_idx=[0, 1, 2],
333442
label_idx=[1, 2],
334443
)
335444
misprd_examples, misprd_labels, misprd_counts = _count_with_examples(
336-
data=data[mask_misprd],
445+
data=detailed_pairs[mask_misprd],
337446
unique_idx=[0, 1],
338447
label_idx=1,
339448
)

0 commit comments

Comments
 (0)