Skip to content

Commit e25cd0f

Browse files
authored
fix obj det label filtering (#867)
1 parent 67832a9 commit e25cd0f

File tree

2 files changed

+113
-4
lines changed

2 files changed

+113
-4
lines changed

src/valor_lite/object_detection/evaluator.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -351,14 +351,17 @@ def filter(
351351
pairs = np.column_stack([tbl[col].to_numpy() for col in columns])
352352

353353
n_pairs = pairs.shape[0]
354-
gt_ids = pairs[:, (0, 1)].astype(np.int64)
355-
pd_ids = pairs[:, (0, 2)].astype(np.int64)
354+
gt_ids = pairs[:, (0, 1, 3)].astype(np.int64)
355+
pd_ids = pairs[:, (0, 2, 4)].astype(np.int64)
356356

357357
if groundtruths is not None:
358358
mask_valid_gt = np.zeros(n_pairs, dtype=np.bool_)
359359
gt_tbl = tbl.filter(groundtruths)
360360
gt_pairs = np.column_stack(
361-
[gt_tbl[col].to_numpy() for col in ("datum_id", "gt_id")]
361+
[
362+
gt_tbl[col].to_numpy()
363+
for col in ("datum_id", "gt_id", "gt_label_id")
364+
]
362365
).astype(np.int64)
363366
for gt in np.unique(gt_pairs, axis=0):
364367
mask_valid_gt |= (gt_ids == gt).all(axis=1)
@@ -369,7 +372,10 @@ def filter(
369372
mask_valid_pd = np.zeros(n_pairs, dtype=np.bool_)
370373
pd_tbl = tbl.filter(predictions)
371374
pd_pairs = np.column_stack(
372-
[pd_tbl[col].to_numpy() for col in ("datum_id", "pd_id")]
375+
[
376+
pd_tbl[col].to_numpy()
377+
for col in ("datum_id", "pd_id", "pd_label_id")
378+
]
373379
).astype(np.int64)
374380
for pd in np.unique(pd_pairs, axis=0):
375381
mask_valid_pd |= (pd_ids == pd).all(axis=1)

tests/object_detection/test_filtering.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -901,3 +901,106 @@ def test_filtering_labels(
901901
},
902902
},
903903
}
904+
905+
906+
def test_filtering_labels_confusion_matrix(loader: Loader, tmp_path):
907+
gts = ["dog", "cat", "cat", "bird", "dog", "dog"]
908+
pds = [
909+
["dog", "cat", "bird"],
910+
["cat", "bird", "dog"],
911+
["dog", "cat", "bird"],
912+
["bird", "cat", "dog"],
913+
["bird", "dog", "cat"],
914+
["dog", "bird", "cat"],
915+
]
916+
scores = [0.6, 0.3, 0.1]
917+
918+
loader.add_bounding_boxes(
919+
[
920+
Detection(
921+
uid=f"datum{i}",
922+
groundtruths=[
923+
BoundingBox(
924+
uid=f"g_{i}_0",
925+
xmin=0,
926+
xmax=10,
927+
ymin=0,
928+
ymax=10,
929+
labels=[gt],
930+
)
931+
],
932+
predictions=[
933+
BoundingBox(
934+
uid=f"p_{i}_0",
935+
xmin=5,
936+
xmax=15,
937+
ymin=0,
938+
ymax=10,
939+
labels=pd,
940+
scores=scores,
941+
), # IOU=0.5
942+
BoundingBox(
943+
uid=f"p_{i}_1",
944+
xmin=2,
945+
xmax=12,
946+
ymin=0,
947+
ymax=10,
948+
labels=pd[1:] + pd[:1], # rotate labels
949+
scores=scores,
950+
), # IOU=0.8
951+
],
952+
metadata=None,
953+
)
954+
for i, (gt, pd) in enumerate(zip(gts, pds))
955+
]
956+
)
957+
evaluator = loader.finalize()
958+
959+
# test normal case
960+
metrics = evaluator.compute_confusion_matrix(
961+
iou_thresholds=[0.1], score_thresholds=[0.5]
962+
)
963+
assert len(metrics) == 1
964+
assert metrics[0].to_dict() == {
965+
"parameters": {
966+
"iou_threshold": 0.1,
967+
"score_threshold": 0.5,
968+
},
969+
"type": "ConfusionMatrix",
970+
"value": {
971+
"confusion_matrix": {
972+
"bird": {"bird": 0, "cat": 1, "dog": 0},
973+
"cat": {"bird": 1, "cat": 1, "dog": 0},
974+
"dog": {"bird": 1, "cat": 1, "dog": 1},
975+
},
976+
"unmatched_ground_truths": {"bird": 0, "cat": 0, "dog": 0},
977+
"unmatched_predictions": {"bird": 2, "cat": 1, "dog": 3},
978+
},
979+
}
980+
981+
# remove 'bird' from class labels
982+
filtered_evaluator = evaluator.filter(
983+
groundtruths=pc.field("gt_label").isin(["dog", "cat"]),
984+
predictions=pc.field("pd_label").isin(["dog", "cat"]),
985+
path=Path(tmp_path) / "filtered",
986+
)
987+
metrics = filtered_evaluator.compute_confusion_matrix(
988+
iou_thresholds=[0.1], score_thresholds=[0.5]
989+
)
990+
assert len(metrics) == 1
991+
assert metrics[0].to_dict() == {
992+
"parameters": {
993+
"iou_threshold": 0.1,
994+
"score_threshold": 0.5,
995+
},
996+
"type": "ConfusionMatrix",
997+
"value": {
998+
"confusion_matrix": {
999+
"bird": {"bird": 0, "cat": 0, "dog": 0},
1000+
"cat": {"bird": 0, "cat": 2, "dog": 0},
1001+
"dog": {"bird": 0, "cat": 1, "dog": 2},
1002+
},
1003+
"unmatched_ground_truths": {"bird": 0, "cat": 0, "dog": 0},
1004+
"unmatched_predictions": {"bird": 0, "cat": 1, "dog": 2},
1005+
},
1006+
}

0 commit comments

Comments
 (0)