5050 unpack_rocauc ,
5151)
5252from valor_lite .exceptions import EmptyCacheError
53+ from valor_lite .filtering import DataType , Expression
5354
5455
5556class Builder :
@@ -58,7 +59,9 @@ def __init__(
5859 writer : MemoryCacheWriter | FileCacheWriter ,
5960 roc_curve_writer : MemoryCacheWriter | FileCacheWriter ,
6061 intermediate_writer : MemoryCacheWriter | FileCacheWriter ,
61- metadata_fields : list [tuple [str , str | pa .DataType ]] | None = None ,
62+ metadata_fields : list [tuple [str , DataType ]]
63+ | list [tuple [str , str ]]
64+ | None = None ,
6265 ):
6366 self ._writer = writer
6467 self ._roc_curve_writer = roc_curve_writer
@@ -69,7 +72,9 @@ def __init__(
6972 def in_memory (
7073 cls ,
7174 batch_size : int = 10_000 ,
72- metadata_fields : list [tuple [str , str | pa .DataType ]] | None = None ,
75+ metadata_fields : list [tuple [str , DataType ]]
76+ | list [tuple [str , str ]]
77+ | None = None ,
7378 ):
7479 """
7580 Create an in-memory evaluator cache.
@@ -78,7 +83,7 @@ def in_memory(
7883 ----------
7984 batch_size : int, default=10_000
8085 The target number of rows to buffer before writing to the cache. Defaults to 10_000.
81- metadata_fields : list[tuple[str, str | pa. DataType]], optional
86+ metadata_fields : list[tuple[str, str | DataType]], optional
8287 Optional metadata field definitions.
8388 """
8489 writer = MemoryCacheWriter .create (
@@ -107,7 +112,9 @@ def persistent(
107112 batch_size : int = 10_000 ,
108113 rows_per_file : int = 100_000 ,
109114 compression : str = "snappy" ,
110- metadata_fields : list [tuple [str , str | pa .DataType ]] | None = None ,
115+ metadata_fields : list [tuple [str , DataType ]]
116+ | list [tuple [str , str ]]
117+ | None = None ,
111118 ):
112119 """
113120 Create a persistent file-based evaluator cache.
@@ -122,7 +129,7 @@ def persistent(
122129 Sets the maximum number of rows per file. This may be exceeded as files are datum aligned.
123130 compression : str, default="snappy"
124131 Sets the pyarrow compression method.
125- metadata_fields : list[tuple[str, str | pa. DataType]], optional
132+ metadata_fields : list[tuple[str, str | DataType]], optional
126133 Optionally sets metadata description for use in filtering.
127134 """
128135 path = Path (path )
@@ -324,20 +331,26 @@ def __init__(
324331 reader : MemoryCacheReader | FileCacheReader ,
325332 roc_curve_reader : MemoryCacheReader | FileCacheReader ,
326333 index_to_label : dict [int , str ],
327- metadata_fields : list [tuple [str , str | pa .DataType ]] | None = None ,
334+ metadata_fields : list [tuple [str , str ]]
335+ | list [tuple [str , DataType ]]
336+ | None = None ,
328337 ):
329338 self ._reader = reader
330339 self ._roc_curve_reader = roc_curve_reader
331340 self ._index_to_label = index_to_label
332- self ._metadata_fields = metadata_fields
341+ self ._metadata_fields = (
342+ [(name , str (dtype )) for name , dtype in metadata_fields ]
343+ if metadata_fields
344+ else None
345+ )
333346
334347 @property
335348 def info (self ) -> EvaluatorInfo :
336349 return self .get_info ()
337350
338351 def get_info (
339352 self ,
340- datums : pc . Expression | None = None ,
353+ datums : Expression | None = None ,
341354 ) -> EvaluatorInfo :
342355 info = EvaluatorInfo ()
343356 info .metadata_fields = self ._metadata_fields
@@ -403,21 +416,21 @@ def load(
403416
404417 def filter (
405418 self ,
406- datums : pc . Expression | None = None ,
407- groundtruths : pc . Expression | None = None ,
408- predictions : pc . Expression | None = None ,
419+ datums : Expression | None = None ,
420+ groundtruths : Expression | None = None ,
421+ predictions : Expression | None = None ,
409422 path : str | Path | None = None ,
410423 ) -> Evaluator :
411424 """
412425 Filter evaluator cache.
413426
414427 Parameters
415428 ----------
416- datums : pc. Expression | None = None
429+ datums : Expression | None = None
417430 A filter expression used to filter datums.
418- groundtruths : pc. Expression | None = None
431+ groundtruths : Expression | None = None
419432 A filter expression used to filter ground truth annotations.
420- predictions : pc. Expression | None = None
433+ predictions : Expression | None = None
421434 A filter expression used to filter predictions.
422435 path : str | Path, optional
423436 Where to store the filtered cache if storing on disk.
@@ -447,7 +460,8 @@ def filter(
447460 metadata_fields = self .info .metadata_fields ,
448461 )
449462
450- for tbl in self ._reader .iterate_tables (filter = datums ):
463+ datum_filter = datums .to_arrow () if datums is not None else None
464+ for tbl in self ._reader .iterate_tables (filter = datum_filter ):
451465 columns = (
452466 "datum_id" ,
453467 "gt_label_id" ,
@@ -461,7 +475,7 @@ def filter(
461475
462476 if groundtruths is not None :
463477 mask_valid_gt = np .zeros (n_pairs , dtype = np .bool_ )
464- gt_tbl = tbl .filter (groundtruths )
478+ gt_tbl = tbl .filter (groundtruths . to_arrow () )
465479 gt_pairs = np .column_stack (
466480 [
467481 gt_tbl [col ].to_numpy ()
@@ -475,7 +489,7 @@ def filter(
475489
476490 if predictions is not None :
477491 mask_valid_pd = np .zeros (n_pairs , dtype = np .bool_ )
478- pd_tbl = tbl .filter (predictions )
492+ pd_tbl = tbl .filter (predictions . to_arrow () )
479493 pd_pairs = np .column_stack (
480494 [
481495 pd_tbl [col ].to_numpy ()
@@ -503,7 +517,7 @@ def filter(
503517
504518 return loader .finalize (index_to_label_override = self ._index_to_label )
505519
506- def iterate_values (self , datums : pc .Expression | None = None ):
520+ def _iterate_values (self , datum_filter : pc .Expression | None = None ):
507521 columns = [
508522 "datum_id" ,
509523 "gt_label_id" ,
@@ -512,7 +526,9 @@ def iterate_values(self, datums: pc.Expression | None = None):
512526 "pd_winner" ,
513527 "match" ,
514528 ]
515- for tbl in self ._reader .iterate_tables (columns = columns , filter = datums ):
529+ for tbl in self ._reader .iterate_tables (
530+ columns = columns , filter = datum_filter
531+ ):
516532 ids = np .column_stack (
517533 [
518534 tbl [col ].to_numpy ()
@@ -528,8 +544,10 @@ def iterate_values(self, datums: pc.Expression | None = None):
528544 matches = tbl ["match" ].to_numpy ()
529545 yield ids , scores , winners , matches
530546
531- def iterate_values_with_tables (self , datums : pc .Expression | None = None ):
532- for tbl in self ._reader .iterate_tables (filter = datums ):
547+ def _iterate_values_with_tables (
548+ self , datum_filter : pc .Expression | None = None
549+ ):
550+ for tbl in self ._reader .iterate_tables (filter = datum_filter ):
533551 ids = np .column_stack (
534552 [
535553 tbl [col ].to_numpy ()
@@ -594,7 +612,7 @@ def compute_precision_recall(
594612 self ,
595613 score_thresholds : list [float ] = [0.0 ],
596614 hardmax : bool = True ,
597- datums : pc . Expression | None = None ,
615+ datums : Expression | None = None ,
598616 ) -> dict [MetricType , list ]:
599617 """
600618 Performs an evaluation and returns metrics.
@@ -613,6 +631,7 @@ def compute_precision_recall(
613631 dict[MetricType, list]
614632 A dictionary mapping MetricType enumerations to lists of computed metrics.
615633 """
634+ datum_filter = datums .to_arrow () if datums is not None else None
616635 if not score_thresholds :
617636 raise ValueError ("At least one score threshold must be passed." )
618637
@@ -623,7 +642,9 @@ def compute_precision_recall(
623642 # intermediates
624643 counts = np .zeros ((n_scores , n_labels , 4 ), dtype = np .uint64 )
625644
626- for ids , scores , winners , _ in self .iterate_values (datums = datums ):
645+ for ids , scores , winners , _ in self ._iterate_values (
646+ datum_filter = datum_filter
647+ ):
627648 batch_counts = compute_counts (
628649 ids = ids ,
629650 scores = scores ,
@@ -654,7 +675,7 @@ def compute_confusion_matrix(
654675 self ,
655676 score_thresholds : list [float ] = [0.0 ],
656677 hardmax : bool = True ,
657- datums : pc . Expression | None = None ,
678+ datums : Expression | None = None ,
658679 ) -> list [Metric ]:
659680 """
660681 Compute a confusion matrix.
@@ -673,6 +694,7 @@ def compute_confusion_matrix(
673694 list[Metric]
674695 A list of confusion matrices.
675696 """
697+ datum_filter = datums .to_arrow () if datums is not None else None
676698 if not score_thresholds :
677699 raise ValueError ("At least one score threshold must be passed." )
678700
@@ -684,8 +706,8 @@ def compute_confusion_matrix(
684706 unmatched_groundtruths = np .zeros (
685707 (n_scores , n_labels ), dtype = np .uint64
686708 )
687- for ids , scores , winners , matches in self .iterate_values (
688- datums = datums
709+ for ids , scores , winners , matches in self ._iterate_values (
710+ datum_filter = datum_filter
689711 ):
690712 (
691713 mask_tp ,
@@ -722,7 +744,7 @@ def compute_examples(
722744 self ,
723745 score_thresholds : list [float ] = [0.0 ],
724746 hardmax : bool = True ,
725- datums : pc . Expression | None = None ,
747+ datums : Expression | None = None ,
726748 ) -> list [Metric ]:
727749 """
728750 Compute examples per datum.
@@ -743,6 +765,7 @@ def compute_examples(
743765 list[Metric]
744766 A list of confusion matrices.
745767 """
768+ datum_filter = datums .to_arrow () if datums is not None else None
746769 if not score_thresholds :
747770 raise ValueError ("At least one score threshold must be passed." )
748771
@@ -753,7 +776,7 @@ def compute_examples(
753776 winners ,
754777 _ ,
755778 tbl ,
756- ) in self .iterate_values_with_tables ( datums = datums ):
779+ ) in self ._iterate_values_with_tables ( datum_filter = datum_filter ):
757780 if ids .size == 0 :
758781 continue
759782
@@ -795,7 +818,7 @@ def compute_confusion_matrix_with_examples(
795818 self ,
796819 score_thresholds : list [float ] = [0.0 ],
797820 hardmax : bool = True ,
798- datums : pc . Expression | None = None ,
821+ datums : Expression | None = None ,
799822 ) -> list [Metric ]:
800823 """
801824 Compute confusion matrix with examples.
@@ -818,6 +841,7 @@ def compute_confusion_matrix_with_examples(
818841 list[Metric]
819842 A list of confusion matrices.
820843 """
844+ datum_filter = datums .to_arrow () if datums is not None else None
821845 if not score_thresholds :
822846 raise ValueError ("At least one score threshold must be passed." )
823847
@@ -835,7 +859,7 @@ def compute_confusion_matrix_with_examples(
835859 winners ,
836860 _ ,
837861 tbl ,
838- ) in self .iterate_values_with_tables ( datums = datums ):
862+ ) in self ._iterate_values_with_tables ( datum_filter = datum_filter ):
839863 if ids .size == 0 :
840864 continue
841865
0 commit comments