@@ -901,3 +901,106 @@ def test_filtering_labels(
             },
         },
     }
+
+
+def test_filtering_labels_confusion_matrix(loader: Loader, tmp_path):
+    gts = ["dog", "cat", "cat", "bird", "dog", "dog"]
+    pds = [
+        ["dog", "cat", "bird"],
+        ["cat", "bird", "dog"],
+        ["dog", "cat", "bird"],
+        ["bird", "cat", "dog"],
+        ["bird", "dog", "cat"],
+        ["dog", "bird", "cat"],
+    ]
+    scores = [0.6, 0.3, 0.1]
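+    # note: scores pair positionally with each box's labels, so only the
+    # first label (score 0.6) can clear the 0.5 score threshold used below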
+
+    loader.add_bounding_boxes(
+        [
+            Detection(
+                uid=f"datum{i}",
+                groundtruths=[
+                    BoundingBox(
+                        uid=f"g_{i}_0",
+                        xmin=0,
+                        xmax=10,
+                        ymin=0,
+                        ymax=10,
+                        labels=[gt],
+                    )
+                ],
+                predictions=[
+                    BoundingBox(
+                        uid=f"p_{i}_0",
+                        xmin=5,
+                        xmax=15,
+                        ymin=0,
+                        ymax=10,
+                        labels=pd,
+                        scores=scores,
+                    ),  # IoU ≈ 0.33 (intersection 50 / union 150)
+                    BoundingBox(
+                        uid=f"p_{i}_1",
+                        xmin=2,
+                        xmax=12,
+                        ymin=0,
+                        ymax=10,
+                        labels=pd[1:] + pd[:1],  # rotate labels
+                        scores=scores,
+                    ),  # IoU ≈ 0.67 (intersection 80 / union 120)
+                ],
+                metadata=None,
+            )
+            for i, (gt, pd) in enumerate(zip(gts, pds))
+        ]
+    )
+    evaluator = loader.finalize()
+
+    # test normal case
+    metrics = evaluator.compute_confusion_matrix(
+        iou_thresholds=[0.1], score_thresholds=[0.5]
+    )
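+    # both prediction boxes clear iou_threshold=0.1; the expected values
+    # below are consistent with the higher-IoU box (p_{i}_1) being matched
+    # to the ground truth, leaving each p_{i}_0 top label unmatched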
+    assert len(metrics) == 1
+    assert metrics[0].to_dict() == {
+        "parameters": {
+            "iou_threshold": 0.1,
+            "score_threshold": 0.5,
+        },
+        "type": "ConfusionMatrix",
+        "value": {
+            "confusion_matrix": {
+                "bird": {"bird": 0, "cat": 1, "dog": 0},
+                "cat": {"bird": 1, "cat": 1, "dog": 0},
+                "dog": {"bird": 1, "cat": 1, "dog": 1},
+            },
+            "unmatched_ground_truths": {"bird": 0, "cat": 0, "dog": 0},
+            "unmatched_predictions": {"bird": 2, "cat": 1, "dog": 3},
+        },
+    }
+
+    # remove 'bird' from class labels
+    filtered_evaluator = evaluator.filter(
+        groundtruths=pc.field("gt_label").isin(["dog", "cat"]),
+        predictions=pc.field("pd_label").isin(["dog", "cat"]),
+        path=Path(tmp_path) / "filtered",
+    )
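+    # filtering drops the "bird" ground truth (datum3) and every ("bird",
+    # score) pair, so a box whose 0.6-score label was "bird" has no label
+    # left above the score threshold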
+    metrics = filtered_evaluator.compute_confusion_matrix(
+        iou_thresholds=[0.1], score_thresholds=[0.5]
+    )
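+    # e.g. for datum5 the higher-IoU box's top label ("bird") is filtered
+    # out, so the "dog" -> "dog" match comes from the lower-IoU box instead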
+    assert len(metrics) == 1
+    assert metrics[0].to_dict() == {
+        "parameters": {
+            "iou_threshold": 0.1,
+            "score_threshold": 0.5,
+        },
+        "type": "ConfusionMatrix",
+        "value": {
+            "confusion_matrix": {
+                "bird": {"bird": 0, "cat": 0, "dog": 0},
+                "cat": {"bird": 0, "cat": 2, "dog": 0},
+                "dog": {"bird": 0, "cat": 1, "dog": 2},
+            },
+            "unmatched_ground_truths": {"bird": 0, "cat": 0, "dog": 0},
+            "unmatched_predictions": {"bird": 0, "cat": 1, "dog": 2},
+        },
+    }