@@ -689,3 +689,215 @@ def test_get_info(
689689 assert info .number_of_labels == 2
690690 assert info .number_of_groundtruth_annotations == 6
691691 assert info .number_of_prediction_annotations == 2
692+
693+
694+ def test_filtering_labels (
695+ loader : Loader , torchmetrics_detections : list [Detection ], tmp_path : Path
696+ ):
697+ loader .add_bounding_boxes (torchmetrics_detections )
698+ evaluator = loader .finalize ()
699+
700+ scores = [0.1 ]
701+ ious = [0.1 ]
702+
703+ assert evaluator ._index_to_label == {
704+ 0 : "4" ,
705+ 1 : "2" ,
706+ 2 : "3" ,
707+ 3 : "1" ,
708+ 4 : "0" ,
709+ 5 : "49" ,
710+ }
711+ assert evaluator .compute_precision_recall (
712+ score_thresholds = scores , iou_thresholds = ious
713+ )
714+ assert evaluator .compute_confusion_matrix (
715+ score_thresholds = scores , iou_thresholds = ious
716+ )
717+ assert evaluator .compute_examples (
718+ score_thresholds = scores , iou_thresholds = ious
719+ )
720+
721+ cm = evaluator .compute_confusion_matrix (
722+ score_thresholds = scores , iou_thresholds = ious
723+ )
724+ assert len (cm ) == 1
725+ assert cm [0 ].to_dict () == {
726+ "parameters" : {
727+ "iou_threshold" : 0.1 ,
728+ "score_threshold" : 0.1 ,
729+ },
730+ "type" : "ConfusionMatrix" ,
731+ "value" : {
732+ "confusion_matrix" : {
733+ "0" : {
734+ "0" : 5 ,
735+ "1" : 0 ,
736+ "2" : 0 ,
737+ "3" : 0 ,
738+ "4" : 0 ,
739+ "49" : 0 ,
740+ },
741+ "1" : {
742+ "0" : 0 ,
743+ "1" : 1 ,
744+ "2" : 0 ,
745+ "3" : 0 ,
746+ "4" : 0 ,
747+ "49" : 0 ,
748+ },
749+ "2" : {
750+ "0" : 0 ,
751+ "1" : 0 ,
752+ "2" : 1 ,
753+ "3" : 1 ,
754+ "4" : 0 ,
755+ "49" : 0 ,
756+ },
757+ "3" : {
758+ "0" : 0 ,
759+ "1" : 0 ,
760+ "2" : 0 ,
761+ "3" : 0 ,
762+ "4" : 0 ,
763+ "49" : 0 ,
764+ },
765+ "4" : {
766+ "0" : 0 ,
767+ "1" : 0 ,
768+ "2" : 0 ,
769+ "3" : 0 ,
770+ "4" : 2 ,
771+ "49" : 0 ,
772+ },
773+ "49" : {
774+ "0" : 0 ,
775+ "1" : 0 ,
776+ "2" : 0 ,
777+ "3" : 0 ,
778+ "4" : 0 ,
779+ "49" : 9 ,
780+ },
781+ },
782+ "unmatched_ground_truths" : {
783+ "0" : 0 ,
784+ "1" : 0 ,
785+ "2" : 0 ,
786+ "3" : 0 ,
787+ "4" : 0 ,
788+ "49" : 1 ,
789+ },
790+ "unmatched_predictions" : {
791+ "0" : 0 ,
792+ "1" : 0 ,
793+ "2" : 0 ,
794+ "3" : 0 ,
795+ "4" : 0 ,
796+ "49" : 0 ,
797+ },
798+ },
799+ }
800+
801+ filtered = evaluator .filter (
802+ groundtruths = pc .field ("gt_label" ).isin (["2" , "3" ]),
803+ predictions = pc .field ("pd_label" ).isin (["2" , "3" ]),
804+ path = tmp_path / "filter" ,
805+ )
806+
807+ assert filtered ._index_to_label == {
808+ 0 : "4" ,
809+ 1 : "2" ,
810+ 2 : "3" ,
811+ 3 : "1" ,
812+ 4 : "0" ,
813+ 5 : "49" ,
814+ }
815+ assert evaluator .compute_precision_recall (
816+ score_thresholds = scores , iou_thresholds = ious
817+ )
818+ assert evaluator .compute_confusion_matrix (
819+ score_thresholds = scores , iou_thresholds = ious
820+ )
821+ assert evaluator .compute_examples (
822+ score_thresholds = scores , iou_thresholds = ious
823+ )
824+
825+ cm = filtered .compute_confusion_matrix (
826+ score_thresholds = scores , iou_thresholds = ious
827+ )
828+ assert len (cm ) == 1
829+ assert cm [0 ].to_dict () == {
830+ "parameters" : {
831+ "iou_threshold" : 0.1 ,
832+ "score_threshold" : 0.1 ,
833+ },
834+ "type" : "ConfusionMatrix" ,
835+ "value" : {
836+ "confusion_matrix" : {
837+ "0" : {
838+ "0" : 0 ,
839+ "1" : 0 ,
840+ "2" : 0 ,
841+ "3" : 0 ,
842+ "4" : 0 ,
843+ "49" : 0 ,
844+ },
845+ "1" : {
846+ "0" : 0 ,
847+ "1" : 0 ,
848+ "2" : 0 ,
849+ "3" : 0 ,
850+ "4" : 0 ,
851+ "49" : 0 ,
852+ },
853+ "2" : {
854+ "0" : 0 ,
855+ "1" : 0 ,
856+ "2" : 1 ,
857+ "3" : 1 ,
858+ "4" : 0 ,
859+ "49" : 0 ,
860+ },
861+ "3" : {
862+ "0" : 0 ,
863+ "1" : 0 ,
864+ "2" : 0 ,
865+ "3" : 0 ,
866+ "4" : 0 ,
867+ "49" : 0 ,
868+ },
869+ "4" : {
870+ "0" : 0 ,
871+ "1" : 0 ,
872+ "2" : 0 ,
873+ "3" : 0 ,
874+ "4" : 0 ,
875+ "49" : 0 ,
876+ },
877+ "49" : {
878+ "0" : 0 ,
879+ "1" : 0 ,
880+ "2" : 0 ,
881+ "3" : 0 ,
882+ "4" : 0 ,
883+ "49" : 0 ,
884+ },
885+ },
886+ "unmatched_ground_truths" : {
887+ "0" : 0 ,
888+ "1" : 0 ,
889+ "2" : 0 ,
890+ "3" : 0 ,
891+ "4" : 0 ,
892+ "49" : 0 ,
893+ },
894+ "unmatched_predictions" : {
895+ "0" : 0 ,
896+ "1" : 0 ,
897+ "2" : 0 ,
898+ "3" : 0 ,
899+ "4" : 0 ,
900+ "49" : 0 ,
901+ },
902+ },
903+ }
0 commit comments