Skip to content

Commit 2c55065

Browse files
committed
wrapped DataType
1 parent 3d5e297 commit 2c55065

File tree

13 files changed

+235
-161
lines changed

13 files changed

+235
-161
lines changed

src/valor_lite/classification/evaluator.py

Lines changed: 54 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
unpack_rocauc,
5151
)
5252
from valor_lite.exceptions import EmptyCacheError
53+
from valor_lite.filtering import DataType, Expression
5354

5455

5556
class Builder:
@@ -58,7 +59,9 @@ def __init__(
5859
writer: MemoryCacheWriter | FileCacheWriter,
5960
roc_curve_writer: MemoryCacheWriter | FileCacheWriter,
6061
intermediate_writer: MemoryCacheWriter | FileCacheWriter,
61-
metadata_fields: list[tuple[str, str | pa.DataType]] | None = None,
62+
metadata_fields: list[tuple[str, DataType]]
63+
| list[tuple[str, str]]
64+
| None = None,
6265
):
6366
self._writer = writer
6467
self._roc_curve_writer = roc_curve_writer
@@ -69,7 +72,9 @@ def __init__(
6972
def in_memory(
7073
cls,
7174
batch_size: int = 10_000,
72-
metadata_fields: list[tuple[str, str | pa.DataType]] | None = None,
75+
metadata_fields: list[tuple[str, DataType]]
76+
| list[tuple[str, str]]
77+
| None = None,
7378
):
7479
"""
7580
Create an in-memory evaluator cache.
@@ -78,7 +83,7 @@ def in_memory(
7883
----------
7984
batch_size : int, default=10_000
8085
The target number of rows to buffer before writing to the cache. Defaults to 10_000.
81-
metadata_fields : list[tuple[str, str | pa.DataType]], optional
86+
metadata_fields : list[tuple[str, str | DataType]], optional
8287
Optional metadata field definitions.
8388
"""
8489
writer = MemoryCacheWriter.create(
@@ -107,7 +112,9 @@ def persistent(
107112
batch_size: int = 10_000,
108113
rows_per_file: int = 100_000,
109114
compression: str = "snappy",
110-
metadata_fields: list[tuple[str, str | pa.DataType]] | None = None,
115+
metadata_fields: list[tuple[str, DataType]]
116+
| list[tuple[str, str]]
117+
| None = None,
111118
):
112119
"""
113120
Create a persistent file-based evaluator cache.
@@ -122,7 +129,7 @@ def persistent(
122129
Sets the maximum number of rows per file. This may be exceeded as files are datum aligned.
123130
compression : str, default="snappy"
124131
Sets the pyarrow compression method.
125-
metadata_fields : list[tuple[str, str | pa.DataType]], optional
132+
metadata_fields : list[tuple[str, str | DataType]], optional
126133
Optionally sets metadata description for use in filtering.
127134
"""
128135
path = Path(path)
@@ -324,20 +331,26 @@ def __init__(
324331
reader: MemoryCacheReader | FileCacheReader,
325332
roc_curve_reader: MemoryCacheReader | FileCacheReader,
326333
index_to_label: dict[int, str],
327-
metadata_fields: list[tuple[str, str | pa.DataType]] | None = None,
334+
metadata_fields: list[tuple[str, str]]
335+
| list[tuple[str, DataType]]
336+
| None = None,
328337
):
329338
self._reader = reader
330339
self._roc_curve_reader = roc_curve_reader
331340
self._index_to_label = index_to_label
332-
self._metadata_fields = metadata_fields
341+
self._metadata_fields = (
342+
[(name, str(dtype)) for name, dtype in metadata_fields]
343+
if metadata_fields
344+
else None
345+
)
333346

334347
@property
335348
def info(self) -> EvaluatorInfo:
336349
return self.get_info()
337350

338351
def get_info(
339352
self,
340-
datums: pc.Expression | None = None,
353+
datums: Expression | None = None,
341354
) -> EvaluatorInfo:
342355
info = EvaluatorInfo()
343356
info.metadata_fields = self._metadata_fields
@@ -403,21 +416,21 @@ def load(
403416

404417
def filter(
405418
self,
406-
datums: pc.Expression | None = None,
407-
groundtruths: pc.Expression | None = None,
408-
predictions: pc.Expression | None = None,
419+
datums: Expression | None = None,
420+
groundtruths: Expression | None = None,
421+
predictions: Expression | None = None,
409422
path: str | Path | None = None,
410423
) -> Evaluator:
411424
"""
412425
Filter evaluator cache.
413426
414427
Parameters
415428
----------
416-
datums : pc.Expression | None = None
429+
datums : Expression | None = None
417430
A filter expression used to filter datums.
418-
groundtruths : pc.Expression | None = None
431+
groundtruths : Expression | None = None
419432
A filter expression used to filter ground truth annotations.
420-
predictions : pc.Expression | None = None
433+
predictions : Expression | None = None
421434
A filter expression used to filter predictions.
422435
path : str | Path, optional
423436
Where to store the filtered cache if storing on disk.
@@ -447,7 +460,8 @@ def filter(
447460
metadata_fields=self.info.metadata_fields,
448461
)
449462

450-
for tbl in self._reader.iterate_tables(filter=datums):
463+
datum_filter = datums.to_arrow() if datums is not None else None
464+
for tbl in self._reader.iterate_tables(filter=datum_filter):
451465
columns = (
452466
"datum_id",
453467
"gt_label_id",
@@ -461,7 +475,7 @@ def filter(
461475

462476
if groundtruths is not None:
463477
mask_valid_gt = np.zeros(n_pairs, dtype=np.bool_)
464-
gt_tbl = tbl.filter(groundtruths)
478+
gt_tbl = tbl.filter(groundtruths.to_arrow())
465479
gt_pairs = np.column_stack(
466480
[
467481
gt_tbl[col].to_numpy()
@@ -475,7 +489,7 @@ def filter(
475489

476490
if predictions is not None:
477491
mask_valid_pd = np.zeros(n_pairs, dtype=np.bool_)
478-
pd_tbl = tbl.filter(predictions)
492+
pd_tbl = tbl.filter(predictions.to_arrow())
479493
pd_pairs = np.column_stack(
480494
[
481495
pd_tbl[col].to_numpy()
@@ -503,7 +517,7 @@ def filter(
503517

504518
return loader.finalize(index_to_label_override=self._index_to_label)
505519

506-
def iterate_values(self, datums: pc.Expression | None = None):
520+
def _iterate_values(self, datum_filter: pc.Expression | None = None):
507521
columns = [
508522
"datum_id",
509523
"gt_label_id",
@@ -512,7 +526,9 @@ def iterate_values(self, datums: pc.Expression | None = None):
512526
"pd_winner",
513527
"match",
514528
]
515-
for tbl in self._reader.iterate_tables(columns=columns, filter=datums):
529+
for tbl in self._reader.iterate_tables(
530+
columns=columns, filter=datum_filter
531+
):
516532
ids = np.column_stack(
517533
[
518534
tbl[col].to_numpy()
@@ -528,8 +544,10 @@ def iterate_values(self, datums: pc.Expression | None = None):
528544
matches = tbl["match"].to_numpy()
529545
yield ids, scores, winners, matches
530546

531-
def iterate_values_with_tables(self, datums: pc.Expression | None = None):
532-
for tbl in self._reader.iterate_tables(filter=datums):
547+
def _iterate_values_with_tables(
548+
self, datum_filter: pc.Expression | None = None
549+
):
550+
for tbl in self._reader.iterate_tables(filter=datum_filter):
533551
ids = np.column_stack(
534552
[
535553
tbl[col].to_numpy()
@@ -594,7 +612,7 @@ def compute_precision_recall(
594612
self,
595613
score_thresholds: list[float] = [0.0],
596614
hardmax: bool = True,
597-
datums: pc.Expression | None = None,
615+
datums: Expression | None = None,
598616
) -> dict[MetricType, list]:
599617
"""
600618
Performs an evaluation and returns metrics.
@@ -613,6 +631,7 @@ def compute_precision_recall(
613631
dict[MetricType, list]
614632
A dictionary mapping MetricType enumerations to lists of computed metrics.
615633
"""
634+
datum_filter = datums.to_arrow() if datums is not None else None
616635
if not score_thresholds:
617636
raise ValueError("At least one score threshold must be passed.")
618637

@@ -623,7 +642,9 @@ def compute_precision_recall(
623642
# intermediates
624643
counts = np.zeros((n_scores, n_labels, 4), dtype=np.uint64)
625644

626-
for ids, scores, winners, _ in self.iterate_values(datums=datums):
645+
for ids, scores, winners, _ in self._iterate_values(
646+
datum_filter=datum_filter
647+
):
627648
batch_counts = compute_counts(
628649
ids=ids,
629650
scores=scores,
@@ -654,7 +675,7 @@ def compute_confusion_matrix(
654675
self,
655676
score_thresholds: list[float] = [0.0],
656677
hardmax: bool = True,
657-
datums: pc.Expression | None = None,
678+
datums: Expression | None = None,
658679
) -> list[Metric]:
659680
"""
660681
Compute a confusion matrix.
@@ -673,6 +694,7 @@ def compute_confusion_matrix(
673694
list[Metric]
674695
A list of confusion matrices.
675696
"""
697+
datum_filter = datums.to_arrow() if datums is not None else None
676698
if not score_thresholds:
677699
raise ValueError("At least one score threshold must be passed.")
678700

@@ -684,8 +706,8 @@ def compute_confusion_matrix(
684706
unmatched_groundtruths = np.zeros(
685707
(n_scores, n_labels), dtype=np.uint64
686708
)
687-
for ids, scores, winners, matches in self.iterate_values(
688-
datums=datums
709+
for ids, scores, winners, matches in self._iterate_values(
710+
datum_filter=datum_filter
689711
):
690712
(
691713
mask_tp,
@@ -722,7 +744,7 @@ def compute_examples(
722744
self,
723745
score_thresholds: list[float] = [0.0],
724746
hardmax: bool = True,
725-
datums: pc.Expression | None = None,
747+
datums: Expression | None = None,
726748
) -> list[Metric]:
727749
"""
728750
Compute examples per datum.
@@ -743,6 +765,7 @@ def compute_examples(
743765
list[Metric]
744766
A list of confusion matrices.
745767
"""
768+
datum_filter = datums.to_arrow() if datums is not None else None
746769
if not score_thresholds:
747770
raise ValueError("At least one score threshold must be passed.")
748771

@@ -753,7 +776,7 @@ def compute_examples(
753776
winners,
754777
_,
755778
tbl,
756-
) in self.iterate_values_with_tables(datums=datums):
779+
) in self._iterate_values_with_tables(datum_filter=datum_filter):
757780
if ids.size == 0:
758781
continue
759782

@@ -795,7 +818,7 @@ def compute_confusion_matrix_with_examples(
795818
self,
796819
score_thresholds: list[float] = [0.0],
797820
hardmax: bool = True,
798-
datums: pc.Expression | None = None,
821+
datums: Expression | None = None,
799822
) -> list[Metric]:
800823
"""
801824
Compute confusion matrix with examples.
@@ -818,6 +841,7 @@ def compute_confusion_matrix_with_examples(
818841
list[Metric]
819842
A list of confusion matrices.
820843
"""
844+
datum_filter = datums.to_arrow() if datums is not None else None
821845
if not score_thresholds:
822846
raise ValueError("At least one score threshold must be passed.")
823847

@@ -835,7 +859,7 @@ def compute_confusion_matrix_with_examples(
835859
winners,
836860
_,
837861
tbl,
838-
) in self.iterate_values_with_tables(datums=datums):
862+
) in self._iterate_values_with_tables(datum_filter=datum_filter):
839863
if ids.size == 0:
840864
continue
841865

src/valor_lite/classification/shared.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from numpy.typing import NDArray
88

99
from valor_lite.cache import FileCacheReader, MemoryCacheReader
10+
from valor_lite.filtering import DataType, Expression
1011

1112

1213
@dataclass
@@ -34,7 +35,9 @@ def generate_metadata_path(path: str | Path) -> Path:
3435

3536

3637
def generate_schema(
37-
metadata_fields: list[tuple[str, str | pa.DataType]] | None
38+
metadata_fields: list[tuple[str, DataType]]
39+
| list[tuple[str, str]]
40+
| None = None
3841
) -> pa.Schema:
3942
metadata_fields = metadata_fields if metadata_fields else []
4043
reserved_fields = [
@@ -59,7 +62,15 @@ def generate_schema(
5962
raise ValueError(
6063
f"metadata fields {conflicting} conflict with reserved fields"
6164
)
62-
return pa.schema(reserved_fields + metadata_fields)
65+
return pa.schema(
66+
reserved_fields
67+
+ [
68+
(name, dtype.to_arrow())
69+
if isinstance(dtype, DataType)
70+
else (name, dtype)
71+
for name, dtype in metadata_fields
72+
]
73+
)
6374

6475

6576
def generate_intermediate_schema() -> pa.Schema:
@@ -83,7 +94,9 @@ def generate_roc_curve_schema() -> pa.Schema:
8394

8495

8596
def encode_metadata_fields(
86-
metadata_fields: list[tuple[str, str | pa.DataType]] | None
97+
metadata_fields: list[tuple[str, DataType]]
98+
| list[tuple[str, str]]
99+
| None = None,
87100
) -> dict[str, str]:
88101
metadata_fields = metadata_fields if metadata_fields else []
89102
return {k: str(v) for k, v in metadata_fields}
@@ -133,10 +146,11 @@ def extract_labels(
133146

134147
def extract_counts(
135148
reader: MemoryCacheReader | FileCacheReader,
136-
datums: pc.Expression | None = None,
149+
datums: Expression | None = None,
137150
):
138151
n_dts = 0
139-
for tbl in reader.iterate_tables(filter=datums):
152+
datum_filter = datums.to_arrow() if datums is not None else None
153+
for tbl in reader.iterate_tables(filter=datum_filter):
140154
n_dts += int(np.unique(tbl["datum_id"].to_numpy()).shape[0])
141155
return n_dts
142156

src/valor_lite/filtering.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
from enum import StrEnum
4-
from typing import Any
4+
from typing import Any, Iterable
55
from zoneinfo import ZoneInfo
66

77
import pyarrow as pa
@@ -167,7 +167,7 @@ def __le__(self, other: Any) -> Expression:
167167
other = other._symbol if isinstance(other, _Symbol) else other
168168
return Expression(self._symbol <= other)
169169

170-
def isin(self, values: set[Any]) -> Expression:
170+
def isin(self, values: Iterable[Any]) -> Expression:
171171
values = {v._symbol if isinstance(v, _Symbol) else v for v in values}
172172
return Expression(self._symbol.isin(values))
173173

0 commit comments

Comments (0)