Skip to content

Commit 67832a9

Browse files
authored
Fix classification filtering (#866)
1 parent e8d329d commit 67832a9

File tree

5 files changed

+483
-28
lines changed

5 files changed

+483
-28
lines changed

src/valor_lite/classification/evaluator.py

Lines changed: 4 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -487,19 +487,12 @@ def filter(
487487
else:
488488
mask_valid_pd = np.ones(n_pairs, dtype=np.bool_)
489489

490-
mask_valid = mask_valid_gt | mask_valid_pd
491-
mask_valid_gt &= mask_valid
492-
mask_valid_pd &= mask_valid
490+
# classifications *must* have a pairing
491+
mask_valid = mask_valid_gt & mask_valid_pd
492+
filtered_tbl = tbl.filter(pa.array(mask_valid))
493493

494-
pairs[~mask_valid_gt, 1] = -1
495-
pairs[~mask_valid_pd, 2] = -1
496-
497-
for idx, col in enumerate(columns):
498-
tbl = tbl.set_column(
499-
tbl.schema.names.index(col), col, pa.array(pairs[:, idx])
500-
)
501494
# TODO (c.zaloom) - improve write strategy, filtered data could be small
502-
loader._writer.write_table(tbl)
495+
loader._writer.write_table(filtered_tbl)
503496

504497
return loader.finalize(index_to_label_override=self._index_to_label)
505498

tests/classification/test_filtering.py

Lines changed: 117 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -410,9 +410,8 @@ def test_filtering_six_classifications_by_annotation(
410410
loader.add_data(six_classifications)
411411
evaluator = loader.finalize()
412412

413-
# test groundtruth filter
413+
# test groundtruth filter - this will completely filter out datums with ground truth label == "0"
414414
filtered_evaluator = evaluator.filter(
415-
datums=pc.field("datum_uid") == "uid0",
416415
groundtruths=pc.field("gt_label") != "0",
417416
path=tmp_path / "groundtruth_filter",
418417
)
@@ -426,9 +425,9 @@ def test_filtering_six_classifications_by_annotation(
426425
"type": "Counts",
427426
"value": {
428427
"tp": 0,
429-
"fp": 1,
428+
"fp": 0,
430429
"fn": 0,
431-
"tn": 0,
430+
"tn": 2,
432431
},
433432
"parameters": {
434433
"score_threshold": 0.5,
@@ -442,7 +441,7 @@ def test_filtering_six_classifications_by_annotation(
442441
"tp": 0,
443442
"fp": 0,
444443
"fn": 0,
445-
"tn": 1,
444+
"tn": 2,
446445
},
447446
"parameters": {
448447
"score_threshold": 0.5,
@@ -456,7 +455,7 @@ def test_filtering_six_classifications_by_annotation(
456455
"tp": 0,
457456
"fp": 0,
458457
"fn": 0,
459-
"tn": 1,
458+
"tn": 2,
460459
},
461460
"parameters": {
462461
"score_threshold": 0.5,
@@ -469,8 +468,8 @@ def test_filtering_six_classifications_by_annotation(
469468
"value": {
470469
"tp": 0,
471470
"fp": 0,
472-
"fn": 0,
473-
"tn": 1,
471+
"fn": 2,
472+
"tn": 0,
474473
},
475474
"parameters": {
476475
"score_threshold": 0.5,
@@ -484,9 +483,8 @@ def test_filtering_six_classifications_by_annotation(
484483
for m in expected_metrics:
485484
assert m in actual_metrics
486485

487-
# test prediction filter
486+
# test prediction filter - this does not filter datums since more than one pred label exists
488487
filtered_evaluator = evaluator.filter(
489-
datums=pc.field("datum_uid") == "uid0",
490488
predictions=pc.field("pd_label") != "0",
491489
path=tmp_path / "prediction_filter",
492490
)
@@ -501,7 +499,7 @@ def test_filtering_six_classifications_by_annotation(
501499
"value": {
502500
"tp": 0,
503501
"fp": 0,
504-
"fn": 1,
502+
"fn": 2,
505503
"tn": 0,
506504
},
507505
"parameters": {
@@ -516,7 +514,7 @@ def test_filtering_six_classifications_by_annotation(
516514
"tp": 0,
517515
"fp": 0,
518516
"fn": 0,
519-
"tn": 1,
517+
"tn": 6,
520518
},
521519
"parameters": {
522520
"score_threshold": 0.5,
@@ -528,9 +526,9 @@ def test_filtering_six_classifications_by_annotation(
528526
"type": "Counts",
529527
"value": {
530528
"tp": 0,
531-
"fp": 0,
529+
"fp": 2,
532530
"fn": 0,
533-
"tn": 1,
531+
"tn": 4,
534532
},
535533
"parameters": {
536534
"score_threshold": 0.5,
@@ -543,8 +541,8 @@ def test_filtering_six_classifications_by_annotation(
543541
"value": {
544542
"tp": 0,
545543
"fp": 0,
546-
"fn": 0,
547-
"tn": 1,
544+
"fn": 2,
545+
"tn": 4,
548546
},
549547
"parameters": {
550548
"score_threshold": 0.5,
@@ -682,3 +680,106 @@ def test_filtering_remove_all(
682680
for v in example.values():
683681
if isinstance(v, list):
684682
assert len(v) == 0
683+
684+
685+
def test_filtering_labels(
686+
loader: Loader,
687+
classifications_animal_example: list[Classification],
688+
tmp_path: Path,
689+
):
690+
loader.add_data(classifications_animal_example)
691+
evaluator = loader.finalize()
692+
693+
assert evaluator._index_to_label == {
694+
0: "bird",
695+
1: "dog",
696+
2: "cat",
697+
}
698+
assert evaluator.compute_precision_recall()
699+
assert evaluator.compute_rocauc()
700+
assert evaluator.compute_confusion_matrix()
701+
assert evaluator.compute_examples()
702+
703+
cm = evaluator.compute_confusion_matrix()
704+
assert len(cm) == 1
705+
assert cm[0].to_dict() == {
706+
"parameters": {
707+
"hardmax": True,
708+
"score_threshold": 0.0,
709+
},
710+
"type": "ConfusionMatrix",
711+
"value": {
712+
"confusion_matrix": {
713+
"bird": {
714+
"bird": 1,
715+
"cat": 1,
716+
"dog": 1,
717+
},
718+
"cat": {
719+
"bird": 0,
720+
"cat": 1,
721+
"dog": 0,
722+
},
723+
"dog": {
724+
"bird": 0,
725+
"cat": 2,
726+
"dog": 0,
727+
},
728+
},
729+
"unmatched_ground_truths": {
730+
"bird": 0,
731+
"cat": 0,
732+
"dog": 0,
733+
},
734+
},
735+
}
736+
737+
filtered = evaluator.filter(
738+
groundtruths=pc.field("gt_label").isin(["bird", "dog"]),
739+
predictions=pc.field("pd_label").isin(["bird", "dog"]),
740+
path=tmp_path / "filter",
741+
)
742+
743+
assert filtered._index_to_label == {
744+
0: "bird",
745+
1: "dog",
746+
2: "cat",
747+
}
748+
assert filtered.compute_precision_recall()
749+
assert filtered.compute_rocauc()
750+
assert filtered.compute_confusion_matrix()
751+
assert filtered.compute_examples()
752+
753+
cm = filtered.compute_confusion_matrix()
754+
assert len(cm) == 1
755+
assert cm[0].to_dict() == {
756+
"parameters": {
757+
"hardmax": True,
758+
"score_threshold": 0.0,
759+
},
760+
"type": "ConfusionMatrix",
761+
"value": {
762+
"confusion_matrix": {
763+
"bird": {
764+
"bird": 1,
765+
"cat": 0,
766+
"dog": 1,
767+
},
768+
"cat": {
769+
"bird": 0,
770+
"cat": 0,
771+
"dog": 0,
772+
},
773+
"dog": {
774+
"bird": 0,
775+
"cat": 0,
776+
"dog": 0,
777+
},
778+
},
779+
"unmatched_ground_truths": {
780+
"bird": 1,
781+
"cat": 0,
782+
"dog": 2,
783+
},
784+
},
785+
}

0 commit comments

Comments
 (0)