Skip to content

Commit 10574f2

Browse files
committed
add obj det test
1 parent 4d09848 commit 10574f2

File tree

1 file changed

+212
-0
lines changed

1 file changed

+212
-0
lines changed

tests/object_detection/test_filtering.py

Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -689,3 +689,215 @@ def test_get_info(
689689
assert info.number_of_labels == 2
690690
assert info.number_of_groundtruth_annotations == 6
691691
assert info.number_of_prediction_annotations == 2
692+
693+
694+
def test_filtering_labels(
695+
loader: Loader, torchmetrics_detections: list[Detection], tmp_path: Path
696+
):
697+
loader.add_bounding_boxes(torchmetrics_detections)
698+
evaluator = loader.finalize()
699+
700+
scores = [0.1]
701+
ious = [0.1]
702+
703+
assert evaluator._index_to_label == {
704+
0: "4",
705+
1: "2",
706+
2: "3",
707+
3: "1",
708+
4: "0",
709+
5: "49",
710+
}
711+
assert evaluator.compute_precision_recall(
712+
score_thresholds=scores, iou_thresholds=ious
713+
)
714+
assert evaluator.compute_confusion_matrix(
715+
score_thresholds=scores, iou_thresholds=ious
716+
)
717+
assert evaluator.compute_examples(
718+
score_thresholds=scores, iou_thresholds=ious
719+
)
720+
721+
cm = evaluator.compute_confusion_matrix(
722+
score_thresholds=scores, iou_thresholds=ious
723+
)
724+
assert len(cm) == 1
725+
assert cm[0].to_dict() == {
726+
"parameters": {
727+
"iou_threshold": 0.1,
728+
"score_threshold": 0.1,
729+
},
730+
"type": "ConfusionMatrix",
731+
"value": {
732+
"confusion_matrix": {
733+
"0": {
734+
"0": 5,
735+
"1": 0,
736+
"2": 0,
737+
"3": 0,
738+
"4": 0,
739+
"49": 0,
740+
},
741+
"1": {
742+
"0": 0,
743+
"1": 1,
744+
"2": 0,
745+
"3": 0,
746+
"4": 0,
747+
"49": 0,
748+
},
749+
"2": {
750+
"0": 0,
751+
"1": 0,
752+
"2": 1,
753+
"3": 1,
754+
"4": 0,
755+
"49": 0,
756+
},
757+
"3": {
758+
"0": 0,
759+
"1": 0,
760+
"2": 0,
761+
"3": 0,
762+
"4": 0,
763+
"49": 0,
764+
},
765+
"4": {
766+
"0": 0,
767+
"1": 0,
768+
"2": 0,
769+
"3": 0,
770+
"4": 2,
771+
"49": 0,
772+
},
773+
"49": {
774+
"0": 0,
775+
"1": 0,
776+
"2": 0,
777+
"3": 0,
778+
"4": 0,
779+
"49": 9,
780+
},
781+
},
782+
"unmatched_ground_truths": {
783+
"0": 0,
784+
"1": 0,
785+
"2": 0,
786+
"3": 0,
787+
"4": 0,
788+
"49": 1,
789+
},
790+
"unmatched_predictions": {
791+
"0": 0,
792+
"1": 0,
793+
"2": 0,
794+
"3": 0,
795+
"4": 0,
796+
"49": 0,
797+
},
798+
},
799+
}
800+
801+
filtered = evaluator.filter(
802+
groundtruths=pc.field("gt_label").isin(["2", "3"]),
803+
predictions=pc.field("pd_label").isin(["2", "3"]),
804+
path=tmp_path / "filter",
805+
)
806+
807+
assert filtered._index_to_label == {
808+
0: "4",
809+
1: "2",
810+
2: "3",
811+
3: "1",
812+
4: "0",
813+
5: "49",
814+
}
815+
assert evaluator.compute_precision_recall(
816+
score_thresholds=scores, iou_thresholds=ious
817+
)
818+
assert evaluator.compute_confusion_matrix(
819+
score_thresholds=scores, iou_thresholds=ious
820+
)
821+
assert evaluator.compute_examples(
822+
score_thresholds=scores, iou_thresholds=ious
823+
)
824+
825+
cm = filtered.compute_confusion_matrix(
826+
score_thresholds=scores, iou_thresholds=ious
827+
)
828+
assert len(cm) == 1
829+
assert cm[0].to_dict() == {
830+
"parameters": {
831+
"iou_threshold": 0.1,
832+
"score_threshold": 0.1,
833+
},
834+
"type": "ConfusionMatrix",
835+
"value": {
836+
"confusion_matrix": {
837+
"0": {
838+
"0": 0,
839+
"1": 0,
840+
"2": 0,
841+
"3": 0,
842+
"4": 0,
843+
"49": 0,
844+
},
845+
"1": {
846+
"0": 0,
847+
"1": 0,
848+
"2": 0,
849+
"3": 0,
850+
"4": 0,
851+
"49": 0,
852+
},
853+
"2": {
854+
"0": 0,
855+
"1": 0,
856+
"2": 1,
857+
"3": 1,
858+
"4": 0,
859+
"49": 0,
860+
},
861+
"3": {
862+
"0": 0,
863+
"1": 0,
864+
"2": 0,
865+
"3": 0,
866+
"4": 0,
867+
"49": 0,
868+
},
869+
"4": {
870+
"0": 0,
871+
"1": 0,
872+
"2": 0,
873+
"3": 0,
874+
"4": 0,
875+
"49": 0,
876+
},
877+
"49": {
878+
"0": 0,
879+
"1": 0,
880+
"2": 0,
881+
"3": 0,
882+
"4": 0,
883+
"49": 0,
884+
},
885+
},
886+
"unmatched_ground_truths": {
887+
"0": 0,
888+
"1": 0,
889+
"2": 0,
890+
"3": 0,
891+
"4": 0,
892+
"49": 0,
893+
},
894+
"unmatched_predictions": {
895+
"0": 0,
896+
"1": 0,
897+
"2": 0,
898+
"3": 0,
899+
"4": 0,
900+
"49": 0,
901+
},
902+
},
903+
}

0 commit comments

Comments
 (0)