Commit 3eba15b

add datum filtering to computations (#858)
1 parent: e727dc8

15 files changed, +690 −349 lines


src/valor_lite/classification/evaluator.py

Lines changed: 74 additions & 40 deletions
@@ -6,7 +6,6 @@
 import numpy as np
 import pyarrow as pa
 import pyarrow.compute as pc
-from numpy.typing import NDArray

 from valor_lite.cache import (
     FileCacheReader,
@@ -30,10 +29,12 @@
     EvaluatorInfo,
     decode_metadata_fields,
     encode_metadata_fields,
+    extract_counts,
+    extract_groundtruth_count_per_label,
+    extract_labels,
     generate_cache_path,
     generate_intermediate_cache_path,
     generate_intermediate_schema,
-    generate_meta,
     generate_metadata_path,
     generate_roc_curve_cache_path,
     generate_roc_curve_schema,
@@ -296,11 +297,11 @@ def finalize(
         # post-process into sorted writer
         reader = self._writer.to_reader()

-        # generate evaluator meta
-        (index_to_label, label_counts, info,) = generate_meta(
-            reader=reader, index_to_label_override=index_to_label_override
+        # extract labels
+        index_to_label = extract_labels(
+            reader=reader,
+            index_to_label_override=index_to_label_override,
         )
-        info.metadata_fields = self._metadata_fields

         self._create_rocauc_intermediate(
             reader=reader,
@@ -312,9 +313,8 @@ def finalize(
         return Evaluator(
             reader=reader,
             roc_curve_reader=roc_curve_reader,
-            info=info,
-            label_counts=label_counts,
             index_to_label=index_to_label,
+            metadata_fields=self._metadata_fields,
         )

@@ -323,19 +323,31 @@ def __init__(
         self,
         reader: MemoryCacheReader | FileCacheReader,
         roc_curve_reader: MemoryCacheReader | FileCacheReader,
-        info: EvaluatorInfo,
-        label_counts: NDArray[np.uint64],
         index_to_label: dict[int, str],
+        metadata_fields: list[tuple[str, str | pa.DataType]] | None = None,
     ):
         self._reader = reader
         self._roc_curve_reader = roc_curve_reader
-        self._info = info
-        self._label_counts = label_counts
         self._index_to_label = index_to_label
+        self._metadata_fields = metadata_fields

     @property
     def info(self) -> EvaluatorInfo:
-        return self._info
+        return self.get_info()
+
+    def get_info(
+        self,
+        datums: pc.Expression | None = None,
+    ) -> EvaluatorInfo:
+        info = EvaluatorInfo()
+        info.metadata_fields = self._metadata_fields
+        info.number_of_rows = self._reader.count_rows()
+        info.number_of_labels = len(self._index_to_label)
+        info.number_of_datums = extract_counts(
+            reader=self._reader,
+            datums=datums,
+        )
+        return info

     @classmethod
     def load(
@@ -369,25 +381,24 @@ def load(
             generate_roc_curve_cache_path(path)
         )

-        # build evaluator meta
-        (
-            index_to_label,
-            label_counts,
-            info,
-        ) = generate_meta(reader, index_to_label_override)
+        # extract labels
+        index_to_label = extract_labels(
+            reader=reader,
+            index_to_label_override=index_to_label_override,
+        )

         # read config
         metadata_path = generate_metadata_path(path)
+        metadata_fields = None
         with open(metadata_path, "r") as f:
             encoded_types = json.load(f)
-            info.metadata_fields = decode_metadata_fields(encoded_types)
+            metadata_fields = decode_metadata_fields(encoded_types)

         return cls(
             reader=reader,
             roc_curve_reader=roc_curve_reader,
-            info=info,
-            label_counts=label_counts,
             index_to_label=index_to_label,
+            metadata_fields=metadata_fields,
         )

     def filter(
@@ -492,7 +503,7 @@ def filter(

         return loader.finalize(index_to_label_override=self._index_to_label)

-    def iterate_values(self):
+    def iterate_values(self, datums: pc.Expression | None = None):
         columns = [
             "datum_id",
             "gt_label_id",
@@ -501,7 +512,7 @@ def iterate_values(self):
             "pd_winner",
             "match",
         ]
-        for tbl in self._reader.iterate_tables(columns=columns):
+        for tbl in self._reader.iterate_tables(columns=columns, filter=datums):
             ids = np.column_stack(
                 [
                     tbl[col].to_numpy()
@@ -517,8 +528,8 @@ def iterate_values(self):
             matches = tbl["match"].to_numpy()
             yield ids, scores, winners, matches

-    def iterate_values_with_tables(self):
-        for tbl in self._reader.iterate_tables():
+    def iterate_values_with_tables(self, datums: pc.Expression | None = None):
+        for tbl in self._reader.iterate_tables(filter=datums):
             ids = np.column_stack(
                 [
                     tbl[col].to_numpy()
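
Both iterators now push the expression down to `iterate_tables(filter=...)`, so only matching rows are materialized per batch. A hedged sketch follows; `evaluator` stands in for an already-built Evaluator, and the "datum_id" column name comes from the column list above.

import pyarrow.compute as pc

# `evaluator` is assumed to be an Evaluator built or loaded elsewhere.
keep = pc.field("datum_id").isin([0, 3, 7])

for ids, scores, winners, matches in evaluator.iterate_values(datums=keep):
    # Each yielded batch contains only rows for the requested datums.
    print(ids.shape, scores.shape, winners.shape, matches.shape)
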
@@ -534,10 +545,17 @@ def iterate_values_with_tables(self):
             matches = tbl["match"].to_numpy()
             yield ids, scores, winners, matches, tbl

-    def compute_rocauc(self) -> dict[MetricType, list[Metric]]:
+    def compute_rocauc(
+        self, datums: pc.Expression | None = None
+    ) -> dict[MetricType, list[Metric]]:
         """
         Compute ROCAUC.

+        Parameters
+        ----------
+        datums : pyarrow.compute.Expression, optional
+            Option to filter datums by an expression.
+
         Returns
         -------
         dict[MetricType, list[Metric]]
@@ -546,20 +564,26 @@ def compute_rocauc(self) -> dict[MetricType, list[Metric]]:
         n_labels = self.info.number_of_labels

         rocauc = np.zeros(n_labels, dtype=np.float64)
+        label_counts = extract_groundtruth_count_per_label(
+            reader=self._reader,
+            number_of_labels=len(self._index_to_label),
+            datums=datums,
+        )

         prev = np.zeros((n_labels, 2), dtype=np.uint64)
         for array in self._roc_curve_reader.iterate_arrays(
             numeric_columns=[
                 "pd_label_id",
                 "cumulative_fp",
                 "cumulative_tp",
-            ]
+            ],
+            filter=datums,
         ):
             rocauc, prev = compute_rocauc(
                 rocauc=rocauc,
                 array=array,
-                gt_count_per_label=self._label_counts[:, 0],
-                pd_count_per_label=self._label_counts[:, 1],
+                gt_count_per_label=label_counts[:, 0],
+                pd_count_per_label=label_counts[:, 1],
                 n_labels=self.info.number_of_labels,
                 prev=prev,
             )
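
Because the per-label ground-truth and prediction counts are now recomputed under the same expression (rather than read from the precomputed label-count array), a filtered ROCAUC stays consistent with the filtered curve data. A usage sketch under the same assumptions as the earlier examples:

import pyarrow.compute as pc

# `evaluator` assumed as before; "datum_id" assumed filterable.
overall = evaluator.compute_rocauc()
subset = evaluator.compute_rocauc(datums=pc.field("datum_id") < 1_000)
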
@@ -576,6 +600,7 @@ def compute_precision_recall(
         self,
         score_thresholds: list[float] = [0.0],
         hardmax: bool = True,
+        datums: pc.Expression | None = None,
     ) -> dict[MetricType, list]:
         """
         Performs an evaluation and returns metrics.
@@ -586,10 +611,8 @@
             A list of score thresholds to compute metrics over.
         hardmax : bool
             Toggles whether a hardmax is applied to predictions.
-        rows_per_chunk : int, default=10_000
-            The number of sorted rows to return in each chunk.
-        read_batch_size : int, default=1_000
-            The maximum number of rows to load in-memory per file.
+        datums : pyarrow.compute.Expression, optional
+            Option to filter datums by an expression.

         Returns
         -------
@@ -606,7 +629,7 @@
         # intermediates
         counts = np.zeros((n_scores, n_labels, 4), dtype=np.uint64)

-        for ids, scores, winners, _ in self.iterate_values():
+        for ids, scores, winners, _ in self.iterate_values(datums=datums):
             batch_counts = compute_counts(
                 ids=ids,
                 scores=scores,
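
Precision/recall metrics accept the same optional expression, which is simply forwarded to `iterate_values`. A sketch combining it with the existing parameters; the `evaluator` instance and the "datum_id" column remain assumptions:

import pyarrow.compute as pc

metrics = evaluator.compute_precision_recall(
    score_thresholds=[0.25, 0.5, 0.75],
    hardmax=True,
    datums=pc.field("datum_id") >= 500,  # assumed filterable column
)
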
@@ -637,6 +660,7 @@ def compute_confusion_matrix(
         self,
         score_thresholds: list[float] = [0.0],
         hardmax: bool = True,
+        datums: pc.Expression | None = None,
     ) -> list[Metric]:
         """
         Compute a confusion matrix.
@@ -647,6 +671,8 @@
             A list of score thresholds to compute metrics over.
         hardmax : bool
             Toggles whether a hardmax is applied to predictions.
+        datums : pyarrow.compute.Expression, optional
+            Option to filter datums by an expression.

         Returns
         -------
@@ -664,7 +690,9 @@
         unmatched_groundtruths = np.zeros(
             (n_scores, n_labels), dtype=np.uint64
         )
-        for ids, scores, winners, matches in self.iterate_values():
+        for ids, scores, winners, matches in self.iterate_values(
+            datums=datums
+        ):
             (
                 mask_tp,
                 mask_fp_fn_misclf,
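
The confusion-matrix computation takes the same filter. PyArrow expressions compose with `&`, `|`, and `~`, so a sketch like the following (again assuming `evaluator` and a filterable "datum_id" column) restricts the matrix to a contiguous slice of datums:

import pyarrow.compute as pc

window = (pc.field("datum_id") >= 100) & (pc.field("datum_id") < 200)
cm = evaluator.compute_confusion_matrix(
    score_thresholds=[0.5],
    datums=window,
)
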
@@ -700,6 +728,7 @@ def compute_examples(
         self,
         score_thresholds: list[float] = [0.0],
         hardmax: bool = True,
+        datums: pc.Expression | None = None,
     ) -> list[Metric]:
         """
         Compute examples per datum.
@@ -712,6 +741,8 @@
             A list of score thresholds to compute metrics over.
         hardmax : bool
             Toggles whether a hardmax is applied to predictions.
+        datums : pyarrow.compute.Expression, optional
+            Option to filter datums by an expression.

         Returns
         -------
@@ -726,9 +757,9 @@
             ids,
             scores,
             winners,
-            matches,
+            _,
             tbl,
-        ) in self.iterate_values_with_tables():
+        ) in self.iterate_values_with_tables(datums=datums):
             if ids.size == 0:
                 continue
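
Per-datum examples follow the same pattern; the filter is forwarded to `iterate_values_with_tables`, so non-matching rows never reach the example-building loop. A sketch under the same assumptions:

import pyarrow.compute as pc

examples = evaluator.compute_examples(
    score_thresholds=[0.5],
    hardmax=True,
    datums=pc.field("datum_id").isin([42]),  # assumed filterable column
)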

@@ -770,6 +801,7 @@ def compute_confusion_matrix_with_examples(
         self,
         score_thresholds: list[float] = [0.0],
         hardmax: bool = True,
+        datums: pc.Expression | None = None,
     ) -> list[Metric]:
         """
         Compute confusion matrix with examples.
@@ -784,6 +816,8 @@
             A list of score thresholds to compute metrics over.
         hardmax : bool
             Toggles whether a hardmax is applied to predictions.
+        datums : pyarrow.compute.Expression, optional
+            Option to filter datums by an expression.

         Returns
         -------
@@ -805,9 +839,9 @@
             ids,
             scores,
             winners,
-            matches,
+            _,
             tbl,
-        ) in self.iterate_values_with_tables():
+        ) in self.iterate_values_with_tables(datums=datums):
             if ids.size == 0:
                 continue
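
Taken together, every computation on the classification `Evaluator` now accepts the same optional `pyarrow.compute.Expression`, including the example-backed confusion matrix. A final hedged sketch reusing one expression across calls; the `evaluator` name and "datum_id" column are assumptions as above:

import pyarrow.compute as pc

subset = pc.field("datum_id") < 100  # assumed filterable column

info = evaluator.get_info(datums=subset)
rocauc = evaluator.compute_rocauc(datums=subset)
cm_examples = evaluator.compute_confusion_matrix_with_examples(
    score_thresholds=[0.5],
    datums=subset,
)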
