@@ -528,23 +528,6 @@ def iterate_values(self, datums: pc.Expression | None = None):
528528 matches = tbl ["match" ].to_numpy ()
529529 yield ids , scores , winners , matches
530530
531- def iterate_values_with_tables (self , datums : pc .Expression | None = None ):
532- for tbl in self ._reader .iterate_tables (filter = datums ):
533- ids = np .column_stack (
534- [
535- tbl [col ].to_numpy ()
536- for col in [
537- "datum_id" ,
538- "gt_label_id" ,
539- "pd_label_id" ,
540- ]
541- ]
542- )
543- scores = tbl ["pd_score" ].to_numpy ()
544- winners = tbl ["pd_winner" ].to_numpy ()
545- matches = tbl ["match" ].to_numpy ()
546- yield ids , scores , winners , matches , tbl
547-
548531 def compute_rocauc (self ) -> dict [MetricType , list [Metric ]]:
549532 """
550533 Compute ROCAUC.
@@ -723,6 +706,8 @@ def compute_examples(
723706 score_thresholds : list [float ] = [0.0 ],
724707 hardmax : bool = True ,
725708 datums : pc .Expression | None = None ,
709+ limit : int | None = None ,
710+ offset : int = 0 ,
726711 ) -> list [Metric ]:
727712 """
728713 Compute examples per datum.
@@ -737,6 +722,10 @@ def compute_examples(
737722 Toggles whether a hardmax is applied to predictions.
738723 datums : pyarrow.compute.Expression, optional
739724 Option to filter datums by an expression.
725+ limit : int, optional
726+ Option to set a limit to the number of returned datum examples.
727+ offset : int, default=0
728+ Option to offset where examples are being created in the datum index.
740729
741730 Returns
742731 -------
@@ -747,16 +736,29 @@ def compute_examples(
747736 raise ValueError ("At least one score threshold must be passed." )
748737
749738 metrics = []
750- for (
751- ids ,
752- scores ,
753- winners ,
754- _ ,
755- tbl ,
756- ) in self . iterate_values_with_tables ( datums = datums ) :
757- if ids . size == 0 :
739+ for tbl in compute . paginate_index (
740+ source = self . _reader ,
741+ column_key = "datum_id" ,
742+ modifier = datums ,
743+ limit = limit ,
744+ offset = offset ,
745+ ):
746+ if tbl . num_rows == 0 :
758747 continue
759748
749+ ids = np .column_stack (
750+ [
751+ tbl [col ].to_numpy ()
752+ for col in [
753+ "datum_id" ,
754+ "gt_label_id" ,
755+ "pd_label_id" ,
756+ ]
757+ ]
758+ )
759+ scores = tbl ["pd_score" ].to_numpy ()
760+ winners = tbl ["pd_winner" ].to_numpy ()
761+
760762 # extract external identifiers
761763 index_to_datum_id = create_mapping (
762764 tbl , ids , 0 , "datum_id" , "datum_uid"
@@ -829,16 +831,23 @@ def compute_confusion_matrix_with_examples(
829831 )
830832 for score_idx , score_thresh in enumerate (score_thresholds )
831833 }
832- for (
833- ids ,
834- scores ,
835- winners ,
836- _ ,
837- tbl ,
838- ) in self .iterate_values_with_tables (datums = datums ):
839- if ids .size == 0 :
834+ for tbl in self ._reader .iterate_tables (filter = datums ):
835+ if tbl .num_rows == 0 :
840836 continue
841837
838+ ids = np .column_stack (
839+ [
840+ tbl [col ].to_numpy ()
841+ for col in [
842+ "datum_id" ,
843+ "gt_label_id" ,
844+ "pd_label_id" ,
845+ ]
846+ ]
847+ )
848+ scores = tbl ["pd_score" ].to_numpy ()
849+ winners = tbl ["pd_winner" ].to_numpy ()
850+
842851 # extract external identifiers
843852 index_to_datum_id = create_mapping (
844853 tbl , ids , 0 , "datum_id" , "datum_uid"
0 commit comments