Skip to content

Commit e830013

Browse files
authored
Remove ability to pass datum filter to compute_rocauc (#863)
1 parent bdd6444 commit e830013

File tree

4 files changed

+100
-12
lines changed

4 files changed

+100
-12
lines changed

Makefile

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
.PHONY: install pre-commit test help
1+
.PHONY: install lint test help
22

33
install:
44
@echo "Installing from source..."
55
pip install -e src/[dev]
66

7-
pre-commit:
7+
lint:
88
@echo "Running pre-commit..."
99
pre-commit install
1010
pre-commit run --all
@@ -19,6 +19,6 @@ test:
1919
help:
2020
@echo "Available targets:"
2121
@echo " install Install from source with developer tools."
22-
@echo " pre-commit Run pre-commit."
22+
@echo " lint Run pre-commit."
2323
@echo " test Run tests."
2424
@echo " help Show this help message."

src/valor_lite/classification/evaluator.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -545,16 +545,12 @@ def iterate_values_with_tables(self, datums: pc.Expression | None = None):
545545
matches = tbl["match"].to_numpy()
546546
yield ids, scores, winners, matches, tbl
547547

548-
def compute_rocauc(
549-
self, datums: pc.Expression | None = None
550-
) -> dict[MetricType, list[Metric]]:
548+
def compute_rocauc(self) -> dict[MetricType, list[Metric]]:
551549
"""
552550
Compute ROCAUC.
553551
554-
Parameters
555-
----------
556-
datums : pyarrow.compute.Expression, optional
557-
Option to filter datums by an expression.
552+
This function does not support direct filtering. To perform evaluation over a filtered
553+
set you must first create a new evaluator using `Evaluator.filter`.
558554
559555
Returns
560556
-------
@@ -567,7 +563,6 @@ def compute_rocauc(
567563
label_counts = extract_groundtruth_count_per_label(
568564
reader=self._reader,
569565
number_of_labels=len(self._index_to_label),
570-
datums=datums,
571566
)
572567

573568
prev = np.zeros((n_labels, 2), dtype=np.uint64)
@@ -577,7 +572,6 @@ def compute_rocauc(
577572
"cumulative_fp",
578573
"cumulative_tp",
579574
],
580-
filter=datums,
581575
):
582576
rocauc, prev = compute_rocauc(
583577
rocauc=rocauc,

tests/classification/test_filtering.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -636,3 +636,49 @@ def test_filtering_six_classifications_inline(
636636
assert m in expected_metrics
637637
for m in expected_metrics:
638638
assert m in actual_metrics
639+
640+
641+
def test_filtering_remove_all(
642+
loader: Loader,
643+
six_classifications: list[Classification],
644+
tmp_path: Path,
645+
):
646+
647+
loader.add_data(six_classifications)
648+
evaluator = loader.finalize()
649+
650+
datums = pc.field("datum_uid") == "does_not_exist"
651+
652+
# test evaluation
653+
base_metrics = evaluator.compute_precision_recall(datums=datums)
654+
with pytest.raises(TypeError) as e:
655+
evaluator.compute_rocauc(datums=datums) # type: ignore - testing
656+
assert "unexpected keyword" in str(e)
657+
confusion = evaluator.compute_confusion_matrix(datums=datums)
658+
examples = evaluator.compute_examples(datums=datums)
659+
660+
for k, mlist in base_metrics.items():
661+
for m in mlist:
662+
if k == MetricType.Counts:
663+
assert isinstance(m.value, dict)
664+
for v in m.value.values():
665+
assert isinstance(v, int)
666+
assert v >= 0
667+
else:
668+
assert isinstance(m.value, float)
669+
assert m.value <= 1.0
670+
assert m.value >= 0.0
671+
for cm in confusion:
672+
assert isinstance(cm.value, dict)
673+
for row in cm.value["confusion_matrix"].values():
674+
for v in row.values():
675+
assert isinstance(v, int)
676+
assert v >= 0
677+
for v in cm.value["unmatched_ground_truths"].values():
678+
assert isinstance(v, int)
679+
assert v >= 0
680+
for example in examples:
681+
assert isinstance(example, dict)
682+
for v in example.values():
683+
if isinstance(v, list):
684+
assert len(v) == 0

tests/classification/test_rocauc.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,3 +267,51 @@ def test_rocauc_with_tabular_example(
267267
assert m in expected_metrics
268268
for m in expected_metrics:
269269
assert m in actual_metrics
270+
271+
272+
def test_rocauc_single_classification(loader: Loader):
273+
data = [
274+
Classification(
275+
uid="uid",
276+
groundtruth="dog",
277+
predictions=["dog", "cat"],
278+
scores=[1.0, 0.0],
279+
)
280+
]
281+
loader.add_data(data)
282+
evaluator = loader.finalize()
283+
284+
metrics = evaluator.compute_rocauc()
285+
286+
# test ROCAUC
287+
actual_metrics = [m.to_dict() for m in metrics[MetricType.ROCAUC]]
288+
expected_metrics = [
289+
{
290+
"type": "ROCAUC",
291+
"value": 0.0,
292+
"parameters": {
293+
"label": "dog",
294+
},
295+
},
296+
{
297+
"type": "ROCAUC",
298+
"value": 0.0,
299+
"parameters": {
300+
"label": "cat",
301+
},
302+
},
303+
]
304+
for m in actual_metrics:
305+
assert m in expected_metrics
306+
for m in expected_metrics:
307+
assert m in actual_metrics
308+
309+
# test mROCAUC
310+
actual_metrics = [m.to_dict() for m in metrics[MetricType.mROCAUC]]
311+
expected_metrics = [
312+
{"type": "mROCAUC", "value": 0.0, "parameters": {}},
313+
]
314+
for m in actual_metrics:
315+
assert m in expected_metrics
316+
for m in expected_metrics:
317+
assert m in actual_metrics

0 commit comments

Comments
 (0)