@@ -410,9 +410,8 @@ def test_filtering_six_classifications_by_annotation(
410410 loader .add_data (six_classifications )
411411 evaluator = loader .finalize ()
412412
413- # test groundtruth filter
413+ # test groundtruth filter - this will completely filter out datums with ground truth label == "0"
414414 filtered_evaluator = evaluator .filter (
415- datums = pc .field ("datum_uid" ) == "uid0" ,
416415 groundtruths = pc .field ("gt_label" ) != "0" ,
417416 path = tmp_path / "groundtruth_filter" ,
418417 )
@@ -426,9 +425,9 @@ def test_filtering_six_classifications_by_annotation(
426425 "type" : "Counts" ,
427426 "value" : {
428427 "tp" : 0 ,
429- "fp" : 1 ,
428+ "fp" : 0 ,
430429 "fn" : 0 ,
431- "tn" : 0 ,
430+ "tn" : 2 ,
432431 },
433432 "parameters" : {
434433 "score_threshold" : 0.5 ,
@@ -442,7 +441,7 @@ def test_filtering_six_classifications_by_annotation(
442441 "tp" : 0 ,
443442 "fp" : 0 ,
444443 "fn" : 0 ,
445- "tn" : 1 ,
444+ "tn" : 2 ,
446445 },
447446 "parameters" : {
448447 "score_threshold" : 0.5 ,
@@ -456,7 +455,7 @@ def test_filtering_six_classifications_by_annotation(
456455 "tp" : 0 ,
457456 "fp" : 0 ,
458457 "fn" : 0 ,
459- "tn" : 1 ,
458+ "tn" : 2 ,
460459 },
461460 "parameters" : {
462461 "score_threshold" : 0.5 ,
@@ -469,8 +468,8 @@ def test_filtering_six_classifications_by_annotation(
469468 "value" : {
470469 "tp" : 0 ,
471470 "fp" : 0 ,
472- "fn" : 0 ,
473- "tn" : 1 ,
471+ "fn" : 2 ,
472+ "tn" : 0 ,
474473 },
475474 "parameters" : {
476475 "score_threshold" : 0.5 ,
@@ -484,9 +483,8 @@ def test_filtering_six_classifications_by_annotation(
484483 for m in expected_metrics :
485484 assert m in actual_metrics
486485
487- # test prediction filter
486+ # test prediction filter - this does not filter out any datums since more than one prediction label exists
488487 filtered_evaluator = evaluator .filter (
489- datums = pc .field ("datum_uid" ) == "uid0" ,
490488 predictions = pc .field ("pd_label" ) != "0" ,
491489 path = tmp_path / "prediction_filter" ,
492490 )
@@ -501,7 +499,7 @@ def test_filtering_six_classifications_by_annotation(
501499 "value" : {
502500 "tp" : 0 ,
503501 "fp" : 0 ,
504- "fn" : 1 ,
502+ "fn" : 2 ,
505503 "tn" : 0 ,
506504 },
507505 "parameters" : {
@@ -516,7 +514,7 @@ def test_filtering_six_classifications_by_annotation(
516514 "tp" : 0 ,
517515 "fp" : 0 ,
518516 "fn" : 0 ,
519- "tn" : 1 ,
517+ "tn" : 6 ,
520518 },
521519 "parameters" : {
522520 "score_threshold" : 0.5 ,
@@ -528,9 +526,9 @@ def test_filtering_six_classifications_by_annotation(
528526 "type" : "Counts" ,
529527 "value" : {
530528 "tp" : 0 ,
531- "fp" : 0 ,
529+ "fp" : 2 ,
532530 "fn" : 0 ,
533- "tn" : 1 ,
531+ "tn" : 4 ,
534532 },
535533 "parameters" : {
536534 "score_threshold" : 0.5 ,
@@ -543,8 +541,8 @@ def test_filtering_six_classifications_by_annotation(
543541 "value" : {
544542 "tp" : 0 ,
545543 "fp" : 0 ,
546- "fn" : 0 ,
547- "tn" : 1 ,
544+ "fn" : 2 ,
545+ "tn" : 4 ,
548546 },
549547 "parameters" : {
550548 "score_threshold" : 0.5 ,
@@ -682,3 +680,106 @@ def test_filtering_remove_all(
682680 for v in example .values ():
683681 if isinstance (v , list ):
684682 assert len (v ) == 0
683+
684+
def test_filtering_labels(
    loader: Loader,
    classifications_animal_example: list[Classification],
    tmp_path: Path,
):
    """Filter ground truths and predictions down to the "bird" and "dog" labels.

    The test first pins the unfiltered confusion matrix, then applies an
    ``isin(["bird", "dog"])`` filter to both ground truths and predictions and
    checks that:

    * the filtered evaluator still exposes all three labels in
      ``_index_to_label``;
    * every metric family still returns a non-empty result;
    * every "cat" row and column of the confusion matrix is zero, while the
      counts 1 ("bird") and 2 ("dog") now appear under
      ``unmatched_ground_truths``.
    """
    label_index = {0: "bird", 1: "dog", 2: "cat"}
    kept_labels = ["bird", "dog"]
    matrix_parameters = {"hardmax": True, "score_threshold": 0.0}

    loader.add_data(classifications_animal_example)
    evaluator = loader.finalize()

    # --- unfiltered baseline -------------------------------------------------
    assert evaluator._index_to_label == label_index
    # Smoke-check that each metric family yields something truthy.
    for compute in (
        evaluator.compute_precision_recall,
        evaluator.compute_rocauc,
        evaluator.compute_confusion_matrix,
        evaluator.compute_examples,
    ):
        assert compute()

    baseline_expected = {
        "type": "ConfusionMatrix",
        "parameters": matrix_parameters,
        "value": {
            "confusion_matrix": {
                "bird": {"bird": 1, "cat": 1, "dog": 1},
                "cat": {"bird": 0, "cat": 1, "dog": 0},
                "dog": {"bird": 0, "cat": 2, "dog": 0},
            },
            "unmatched_ground_truths": {"bird": 0, "cat": 0, "dog": 0},
        },
    }
    baseline_matrices = evaluator.compute_confusion_matrix()
    assert len(baseline_matrices) == 1
    assert baseline_matrices[0].to_dict() == baseline_expected

    # --- keep only "bird" and "dog" on both sides ----------------------------
    filtered = evaluator.filter(
        groundtruths=pc.field("gt_label").isin(kept_labels),
        predictions=pc.field("pd_label").isin(kept_labels),
        path=tmp_path / "filter",
    )

    # The label index is expected to be unchanged by the filter.
    assert filtered._index_to_label == label_index
    for compute in (
        filtered.compute_precision_recall,
        filtered.compute_rocauc,
        filtered.compute_confusion_matrix,
        filtered.compute_examples,
    ):
        assert compute()

    filtered_expected = {
        "type": "ConfusionMatrix",
        "parameters": matrix_parameters,
        "value": {
            "confusion_matrix": {
                "bird": {"bird": 1, "cat": 0, "dog": 1},
                "cat": {"bird": 0, "cat": 0, "dog": 0},
                "dog": {"bird": 0, "cat": 0, "dog": 0},
            },
            "unmatched_ground_truths": {"bird": 1, "cat": 0, "dog": 2},
        },
    }
    filtered_matrices = filtered.compute_confusion_matrix()
    assert len(filtered_matrices) == 1
    assert filtered_matrices[0].to_dict() == filtered_expected
0 commit comments