66import numpy as np
77import pyarrow as pa
88import pyarrow .compute as pc
9- from numpy .typing import NDArray
109
1110from valor_lite .cache import (
1211 FileCacheReader ,
3029 EvaluatorInfo ,
3130 decode_metadata_fields ,
3231 encode_metadata_fields ,
32+ extract_counts ,
33+ extract_groundtruth_count_per_label ,
34+ extract_labels ,
3335 generate_cache_path ,
3436 generate_intermediate_cache_path ,
3537 generate_intermediate_schema ,
36- generate_meta ,
3738 generate_metadata_path ,
3839 generate_roc_curve_cache_path ,
3940 generate_roc_curve_schema ,
@@ -296,11 +297,11 @@ def finalize(
296297 # post-process into sorted writer
297298 reader = self ._writer .to_reader ()
298299
299- # generate evaluator meta
300- (index_to_label , label_counts , info ,) = generate_meta (
301- reader = reader , index_to_label_override = index_to_label_override
300+ # extract labels
301+ index_to_label = extract_labels (
302+ reader = reader ,
303+ index_to_label_override = index_to_label_override ,
302304 )
303- info .metadata_fields = self ._metadata_fields
304305
305306 self ._create_rocauc_intermediate (
306307 reader = reader ,
@@ -312,9 +313,8 @@ def finalize(
312313 return Evaluator (
313314 reader = reader ,
314315 roc_curve_reader = roc_curve_reader ,
315- info = info ,
316- label_counts = label_counts ,
317316 index_to_label = index_to_label ,
317+ metadata_fields = self ._metadata_fields ,
318318 )
319319
320320
@@ -323,19 +323,31 @@ def __init__(
323323 self ,
324324 reader : MemoryCacheReader | FileCacheReader ,
325325 roc_curve_reader : MemoryCacheReader | FileCacheReader ,
326- info : EvaluatorInfo ,
327- label_counts : NDArray [np .uint64 ],
328326 index_to_label : dict [int , str ],
327+ metadata_fields : list [tuple [str , str | pa .DataType ]] | None = None ,
329328 ):
330329 self ._reader = reader
331330 self ._roc_curve_reader = roc_curve_reader
332- self ._info = info
333- self ._label_counts = label_counts
334331 self ._index_to_label = index_to_label
332+ self ._metadata_fields = metadata_fields
335333
336334 @property
337335 def info (self ) -> EvaluatorInfo :
338- return self ._info
336+ return self .get_info ()
337+
338+ def get_info (
339+ self ,
340+ datums : pc .Expression | None = None ,
341+ ) -> EvaluatorInfo :
342+ info = EvaluatorInfo ()
343+ info .metadata_fields = self ._metadata_fields
344+ info .number_of_rows = self ._reader .count_rows ()
345+ info .number_of_labels = len (self ._index_to_label )
346+ info .number_of_datums = extract_counts (
347+ reader = self ._reader ,
348+ datums = datums ,
349+ )
350+ return info
339351
340352 @classmethod
341353 def load (
@@ -369,25 +381,24 @@ def load(
369381 generate_roc_curve_cache_path (path )
370382 )
371383
372- # build evaluator meta
373- (
374- index_to_label ,
375- label_counts ,
376- info ,
377- ) = generate_meta (reader , index_to_label_override )
384+ # extract labels
385+ index_to_label = extract_labels (
386+ reader = reader ,
387+ index_to_label_override = index_to_label_override ,
388+ )
378389
379390 # read config
380391 metadata_path = generate_metadata_path (path )
392+ metadata_fields = None
381393 with open (metadata_path , "r" ) as f :
382394 encoded_types = json .load (f )
383- info . metadata_fields = decode_metadata_fields (encoded_types )
395+ metadata_fields = decode_metadata_fields (encoded_types )
384396
385397 return cls (
386398 reader = reader ,
387399 roc_curve_reader = roc_curve_reader ,
388- info = info ,
389- label_counts = label_counts ,
390400 index_to_label = index_to_label ,
401+ metadata_fields = metadata_fields ,
391402 )
392403
393404 def filter (
@@ -492,7 +503,7 @@ def filter(
492503
493504 return loader .finalize (index_to_label_override = self ._index_to_label )
494505
495- def iterate_values (self ):
506+ def iterate_values (self , datums : pc . Expression | None = None ):
496507 columns = [
497508 "datum_id" ,
498509 "gt_label_id" ,
@@ -501,7 +512,7 @@ def iterate_values(self):
501512 "pd_winner" ,
502513 "match" ,
503514 ]
504- for tbl in self ._reader .iterate_tables (columns = columns ):
515+ for tbl in self ._reader .iterate_tables (columns = columns , filter = datums ):
505516 ids = np .column_stack (
506517 [
507518 tbl [col ].to_numpy ()
@@ -517,8 +528,8 @@ def iterate_values(self):
517528 matches = tbl ["match" ].to_numpy ()
518529 yield ids , scores , winners , matches
519530
520- def iterate_values_with_tables (self ):
521- for tbl in self ._reader .iterate_tables ():
531+ def iterate_values_with_tables (self , datums : pc . Expression | None = None ):
532+ for tbl in self ._reader .iterate_tables (filter = datums ):
522533 ids = np .column_stack (
523534 [
524535 tbl [col ].to_numpy ()
@@ -534,10 +545,17 @@ def iterate_values_with_tables(self):
534545 matches = tbl ["match" ].to_numpy ()
535546 yield ids , scores , winners , matches , tbl
536547
537- def compute_rocauc (self ) -> dict [MetricType , list [Metric ]]:
548+ def compute_rocauc (
549+ self , datums : pc .Expression | None = None
550+ ) -> dict [MetricType , list [Metric ]]:
538551 """
539552 Compute ROCAUC.
540553
554+ Parameters
555+ ----------
556+ datums : pyarrow.compute.Expression, optional
557+ Option to filter datums by an expression.
558+
541559 Returns
542560 -------
543561 dict[MetricType, list[Metric]]
@@ -546,20 +564,26 @@ def compute_rocauc(self) -> dict[MetricType, list[Metric]]:
546564 n_labels = self .info .number_of_labels
547565
548566 rocauc = np .zeros (n_labels , dtype = np .float64 )
567+ label_counts = extract_groundtruth_count_per_label (
568+ reader = self ._reader ,
569+ number_of_labels = len (self ._index_to_label ),
570+ datums = datums ,
571+ )
549572
550573 prev = np .zeros ((n_labels , 2 ), dtype = np .uint64 )
551574 for array in self ._roc_curve_reader .iterate_arrays (
552575 numeric_columns = [
553576 "pd_label_id" ,
554577 "cumulative_fp" ,
555578 "cumulative_tp" ,
556- ]
579+ ],
580+ filter = datums ,
557581 ):
558582 rocauc , prev = compute_rocauc (
559583 rocauc = rocauc ,
560584 array = array ,
561- gt_count_per_label = self . _label_counts [:, 0 ],
562- pd_count_per_label = self . _label_counts [:, 1 ],
585+ gt_count_per_label = label_counts [:, 0 ],
586+ pd_count_per_label = label_counts [:, 1 ],
563587 n_labels = self .info .number_of_labels ,
564588 prev = prev ,
565589 )
@@ -576,6 +600,7 @@ def compute_precision_recall(
576600 self ,
577601 score_thresholds : list [float ] = [0.0 ],
578602 hardmax : bool = True ,
603+ datums : pc .Expression | None = None ,
579604 ) -> dict [MetricType , list ]:
580605 """
581606 Performs an evaluation and returns metrics.
@@ -586,10 +611,8 @@ def compute_precision_recall(
586611 A list of score thresholds to compute metrics over.
587612 hardmax : bool
588613 Toggles whether a hardmax is applied to predictions.
589- rows_per_chunk : int, default=10_000
590- The number of sorted rows to return in each chunk.
591- read_batch_size : int, default=1_000
592- The maximum number of rows to load in-memory per file.
614+ datums : pyarrow.compute.Expression, optional
615+ Option to filter datums by an expression.
593616
594617 Returns
595618 -------
@@ -606,7 +629,7 @@ def compute_precision_recall(
606629 # intermediates
607630 counts = np .zeros ((n_scores , n_labels , 4 ), dtype = np .uint64 )
608631
609- for ids , scores , winners , _ in self .iterate_values ():
632+ for ids , scores , winners , _ in self .iterate_values (datums = datums ):
610633 batch_counts = compute_counts (
611634 ids = ids ,
612635 scores = scores ,
@@ -637,6 +660,7 @@ def compute_confusion_matrix(
637660 self ,
638661 score_thresholds : list [float ] = [0.0 ],
639662 hardmax : bool = True ,
663+ datums : pc .Expression | None = None ,
640664 ) -> list [Metric ]:
641665 """
642666 Compute a confusion matrix.
@@ -647,6 +671,8 @@ def compute_confusion_matrix(
647671 A list of score thresholds to compute metrics over.
648672 hardmax : bool
649673 Toggles whether a hardmax is applied to predictions.
674+ datums : pyarrow.compute.Expression, optional
675+ Option to filter datums by an expression.
650676
651677 Returns
652678 -------
@@ -664,7 +690,9 @@ def compute_confusion_matrix(
664690 unmatched_groundtruths = np .zeros (
665691 (n_scores , n_labels ), dtype = np .uint64
666692 )
667- for ids , scores , winners , matches in self .iterate_values ():
693+ for ids , scores , winners , matches in self .iterate_values (
694+ datums = datums
695+ ):
668696 (
669697 mask_tp ,
670698 mask_fp_fn_misclf ,
@@ -700,6 +728,7 @@ def compute_examples(
700728 self ,
701729 score_thresholds : list [float ] = [0.0 ],
702730 hardmax : bool = True ,
731+ datums : pc .Expression | None = None ,
703732 ) -> list [Metric ]:
704733 """
705734 Compute examples per datum.
@@ -712,6 +741,8 @@ def compute_examples(
712741 A list of score thresholds to compute metrics over.
713742 hardmax : bool
714743 Toggles whether a hardmax is applied to predictions.
744+ datums : pyarrow.compute.Expression, optional
745+ Option to filter datums by an expression.
715746
716747 Returns
717748 -------
@@ -726,9 +757,9 @@ def compute_examples(
726757 ids ,
727758 scores ,
728759 winners ,
729- matches ,
760+ _ ,
730761 tbl ,
731- ) in self .iterate_values_with_tables ():
762+ ) in self .iterate_values_with_tables (datums = datums ):
732763 if ids .size == 0 :
733764 continue
734765
@@ -770,6 +801,7 @@ def compute_confusion_matrix_with_examples(
770801 self ,
771802 score_thresholds : list [float ] = [0.0 ],
772803 hardmax : bool = True ,
804+ datums : pc .Expression | None = None ,
773805 ) -> list [Metric ]:
774806 """
775807 Compute confusion matrix with examples.
@@ -784,6 +816,8 @@ def compute_confusion_matrix_with_examples(
784816 A list of score thresholds to compute metrics over.
785817 hardmax : bool
786818 Toggles whether a hardmax is applied to predictions.
819+ datums : pyarrow.compute.Expression, optional
820+ Option to filter datums by an expression.
787821
788822 Returns
789823 -------
@@ -805,9 +839,9 @@ def compute_confusion_matrix_with_examples(
805839 ids ,
806840 scores ,
807841 winners ,
808- matches ,
842+ _ ,
809843 tbl ,
810- ) in self .iterate_values_with_tables ():
844+ ) in self .iterate_values_with_tables (datums = datums ):
811845 if ids .size == 0 :
812846 continue
813847
0 commit comments