Skip to content

Commit 0c7e5a4

Browse files
committed
add pagination to examples
1 parent e830013 commit 0c7e5a4

File tree

8 files changed

+635
-46
lines changed

8 files changed

+635
-46
lines changed

src/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ docs = [
3333
]
3434
test = ["pytest", "coverage", "pre-commit"]
3535
benchmark = ["requests"]
36-
dev = ["valor-lite[nlp, openai, mistral, benchmark, test, docs]"]
36+
dev = ["valor-lite[nlp, openai, mistral, benchmark, test, docs]", "pyarrow-stubs"]
3737

3838
[tool.black]
3939
line-length = 79

src/valor_lite/cache/compute.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import heapq
22
import tempfile
33
from pathlib import Path
4-
from typing import Callable
4+
from typing import Callable, Generator
55

66
import pyarrow as pa
7+
import pyarrow.compute as pc
78

89
from valor_lite.cache.ephemeral import MemoryCacheReader, MemoryCacheWriter
910
from valor_lite.cache.persistent import FileCacheReader, FileCacheWriter
@@ -152,3 +153,56 @@ def sort(
152153
columns=columns,
153154
table_sort_override=table_sort_override,
154155
)
156+
157+
158+
def paginate_index(
    source: MemoryCacheReader | FileCacheReader,
    column_key: str,
    modifier: pc.Expression | None = None,
    limit: int | None = None,
    offset: int = 0,
) -> Generator[pa.Table, None, None]:
    """
    Paginate over the unique values of an index column.

    Across the whole iteration, only rows whose `column_key` value falls
    within the half-open page ``[offset, offset + limit)`` of the sorted
    unique-key index are yielded.

    Note this function expects unique keys to be fragment-aligned and in
    ascending order.

    Parameters
    ----------
    source : MemoryCacheReader | FileCacheReader
        Cache reader whose tables are iterated.
    column_key : str
        Name of the column whose unique values form the pagination index.
    modifier : pyarrow.compute.Expression, optional
        Additional row filter applied before pagination.
    limit : int, optional
        Maximum number of unique index values to yield rows for.
        ``None`` means no limit; ``0`` yields nothing.
    offset : int, default=0
        Number of unique index values to skip before yielding.

    Yields
    ------
    pyarrow.Table
        Tables filtered down to the requested page.

    Raises
    ------
    ValueError
        If `offset` or `limit` is negative.
    """
    if offset < 0 or (limit is not None and limit < 0):
        raise ValueError("offset and limit must be non-negative")

    # NOTE: row count is an upper bound on the number of unique keys, so
    # the fast-path comparisons below remain conservative.
    total = source.count_rows()
    if limit is None:
        limit = total
    elif limit == 0:
        # an explicit zero limit selects an empty page
        # (previously `limit if limit else total` treated 0 as "no limit")
        return

    # pagination broader than data scope
    if offset == 0 and limit >= total:
        yield from source.iterate_tables(filter=modifier)
        return
    elif offset >= total:
        return

    curr_idx = 0
    for tbl in source.iterate_tables(filter=modifier):
        if tbl.num_rows == 0:
            continue

        unique_values = pc.unique(tbl[column_key]).sort()  # type: ignore[reportAttributeAccessIssue]
        n_unique = len(unique_values)
        prev_idx = curr_idx
        curr_idx += n_unique

        # check for page overlap
        if curr_idx <= offset:
            # fragment lies entirely before the page
            continue
        elif prev_idx >= (offset + limit):
            # fragment lies entirely after the page; keys ascend, so stop
            return

        # apply any pagination conditions
        condition = pc.scalar(True)
        if prev_idx < offset and curr_idx > offset:
            # the page starts inside this fragment
            condition &= (
                pc.field(column_key) >= unique_values[offset - prev_idx]
            )
        if prev_idx < (offset + limit) and curr_idx > (offset + limit):
            # the page ends inside this fragment
            condition &= (
                pc.field(column_key) < unique_values[offset + limit - prev_idx]
            )

        yield tbl.filter(condition)

src/valor_lite/classification/evaluator.py

Lines changed: 42 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -528,23 +528,6 @@ def iterate_values(self, datums: pc.Expression | None = None):
528528
matches = tbl["match"].to_numpy()
529529
yield ids, scores, winners, matches
530530

531-
def iterate_values_with_tables(self, datums: pc.Expression | None = None):
532-
for tbl in self._reader.iterate_tables(filter=datums):
533-
ids = np.column_stack(
534-
[
535-
tbl[col].to_numpy()
536-
for col in [
537-
"datum_id",
538-
"gt_label_id",
539-
"pd_label_id",
540-
]
541-
]
542-
)
543-
scores = tbl["pd_score"].to_numpy()
544-
winners = tbl["pd_winner"].to_numpy()
545-
matches = tbl["match"].to_numpy()
546-
yield ids, scores, winners, matches, tbl
547-
548531
def compute_rocauc(self) -> dict[MetricType, list[Metric]]:
549532
"""
550533
Compute ROCAUC.
@@ -723,6 +706,8 @@ def compute_examples(
723706
score_thresholds: list[float] = [0.0],
724707
hardmax: bool = True,
725708
datums: pc.Expression | None = None,
709+
limit: int | None = None,
710+
offset: int = 0,
726711
) -> list[Metric]:
727712
"""
728713
Compute examples per datum.
@@ -737,6 +722,10 @@ def compute_examples(
737722
Toggles whether a hardmax is applied to predictions.
738723
datums : pyarrow.compute.Expression, optional
739724
Option to filter datums by an expression.
725+
limit : int, optional
726+
Option to set a limit to the number of returned datum examples.
727+
offset : int, default=0
728+
Option to offset where examples are being created in the datum index.
740729
741730
Returns
742731
-------
@@ -747,16 +736,29 @@ def compute_examples(
747736
raise ValueError("At least one score threshold must be passed.")
748737

749738
metrics = []
750-
for (
751-
ids,
752-
scores,
753-
winners,
754-
_,
755-
tbl,
756-
) in self.iterate_values_with_tables(datums=datums):
757-
if ids.size == 0:
739+
for tbl in compute.paginate_index(
740+
source=self._reader,
741+
column_key="datum_id",
742+
modifier=datums,
743+
limit=limit,
744+
offset=offset,
745+
):
746+
if tbl.num_rows == 0:
758747
continue
759748

749+
ids = np.column_stack(
750+
[
751+
tbl[col].to_numpy()
752+
for col in [
753+
"datum_id",
754+
"gt_label_id",
755+
"pd_label_id",
756+
]
757+
]
758+
)
759+
scores = tbl["pd_score"].to_numpy()
760+
winners = tbl["pd_winner"].to_numpy()
761+
760762
# extract external identifiers
761763
index_to_datum_id = create_mapping(
762764
tbl, ids, 0, "datum_id", "datum_uid"
@@ -829,16 +831,23 @@ def compute_confusion_matrix_with_examples(
829831
)
830832
for score_idx, score_thresh in enumerate(score_thresholds)
831833
}
832-
for (
833-
ids,
834-
scores,
835-
winners,
836-
_,
837-
tbl,
838-
) in self.iterate_values_with_tables(datums=datums):
839-
if ids.size == 0:
834+
for tbl in self._reader.iterate_tables(filter=datums):
835+
if tbl.num_rows == 0:
840836
continue
841837

838+
ids = np.column_stack(
839+
[
840+
tbl[col].to_numpy()
841+
for col in [
842+
"datum_id",
843+
"gt_label_id",
844+
"pd_label_id",
845+
]
846+
]
847+
)
848+
scores = tbl["pd_score"].to_numpy()
849+
winners = tbl["pd_winner"].to_numpy()
850+
842851
# extract external identifiers
843852
index_to_datum_id = create_mapping(
844853
tbl, ids, 0, "datum_id", "datum_uid"

src/valor_lite/object_detection/evaluator.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -600,6 +600,8 @@ def compute_examples(
600600
iou_thresholds: list[float],
601601
score_thresholds: list[float],
602602
datums: pc.Expression | None = None,
603+
limit: int | None = None,
604+
offset: int = 0,
603605
) -> list[Metric]:
604606
"""
605607
Computes examples at various thresholds.
@@ -614,6 +616,10 @@ def compute_examples(
614616
A list of score thresholds to compute metrics over.
615617
datums : pyarrow.compute.Expression, optional
616618
Option to filter datums by an expression.
619+
limit : int, optional
620+
Option to set a limit to the number of returned datum examples.
621+
offset : int, default=0
622+
Option to offset where examples are being created in the datum index.
617623
618624
Returns
619625
-------
@@ -626,11 +632,6 @@ def compute_examples(
626632
raise ValueError("At least one score threshold must be passed.")
627633

628634
metrics = []
629-
tbl_columns = [
630-
"datum_uid",
631-
"gt_uid",
632-
"pd_uid",
633-
]
634635
numeric_columns = [
635636
"datum_id",
636637
"gt_id",
@@ -640,14 +641,20 @@ def compute_examples(
640641
"iou",
641642
"pd_score",
642643
]
643-
for tbl, pairs in self._detailed_reader.iterate_tables_with_arrays(
644-
columns=tbl_columns + numeric_columns,
645-
numeric_columns=numeric_columns,
646-
filter=datums,
644+
for tbl in compute.paginate_index(
645+
source=self._detailed_reader,
646+
column_key="datum_id",
647+
modifier=datums,
648+
limit=limit,
649+
offset=offset,
647650
):
648-
if pairs.size == 0:
651+
if tbl.num_rows == 0:
649652
continue
650653

654+
pairs = np.column_stack(
655+
[tbl[col].to_numpy() for col in numeric_columns]
656+
)
657+
651658
index_to_datum_id = {}
652659
index_to_groundtruth_id = {}
653660
index_to_prediction_id = {}

tests/classification/test_examples.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -681,3 +681,65 @@ def test_examples_without_hardmax_animal_example(
681681
assert m in expected_metrics
682682
for m in expected_metrics:
683683
assert m in actual_metrics
684+
685+
686+
def test_examples_with_color_example_paginated(
    loader: Loader,
    classifications_color_example: list[Classification],
):
    """Check datum examples when paginated with limit=3, offset=1."""
    loader.add_data(classifications_color_example)
    evaluator = loader.finalize()

    metrics = evaluator.compute_examples(
        score_thresholds=[0.5],
        limit=3,
        offset=1,
    )
    actual_metrics = [metric.to_dict() for metric in metrics]

    # the page (offset=1, limit=3) covers datums uid1 through uid3
    expected_metrics = [
        {
            "type": "Examples",
            "value": {
                "datum_id": datum_uid,
                "true_positives": [],
                "false_positives": false_positives,
                "false_negatives": false_negatives,
            },
            "parameters": {
                "score_threshold": 0.5,
                "hardmax": True,
            },
        }
        for datum_uid, false_positives, false_negatives in [
            ("uid1", ["blue"], ["white"]),
            ("uid2", [], ["red"]),
            ("uid3", ["white"], ["blue"]),
        ]
    ]
    for metric in actual_metrics:
        assert metric in expected_metrics
    for metric in expected_metrics:
        assert metric in actual_metrics

tests/common/conftest.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from pathlib import Path
2+
from typing import Callable
23

4+
import pyarrow as pa
35
import pytest
46

57
from valor_lite.cache.ephemeral import MemoryCacheWriter
@@ -20,7 +22,9 @@
2022
"in-memory_small_chunks",
2123
],
2224
)
23-
def create_writer(request, tmp_path: Path):
25+
def create_writer(
26+
request, tmp_path: Path
27+
) -> Callable[[pa.Schema], MemoryCacheWriter | FileCacheWriter]:
2428
file_type, batch_size, rows_per_file = request.param
2529
match file_type:
2630
case "memory":
@@ -35,3 +39,5 @@ def create_writer(request, tmp_path: Path):
3539
batch_size=batch_size,
3640
rows_per_file=rows_per_file,
3741
)
42+
case unknown:
43+
raise RuntimeError(unknown)

0 commit comments

Comments
 (0)