1+ """Data selectors."""
12from typing import Dict , List
23
34import numpy as np
67from trinity .buffer .reader .file_reader import _HFBatchReader
78from trinity .buffer .selector .diff_estimator import InterpolationBetaPREstimator
89from trinity .common .config import DataSelectorConfig
10+ from trinity .utils .annotations import Experimental
911from trinity .utils .log import get_logger
1012from trinity .utils .registry import Registry
1113
1214SELECTORS = Registry ("selectors" )
1315
1416
17+ @Experimental
1518class BaseSelector :
19+ """
20+ Abstract base class defining the interface for custom data selection strategies.
21+
22+ A selector determines which samples (by index) are selected from the dataset
23+ during training. It enables flexible sampling beyond simple
24+ sequential or random access, supporting active learning, curriculum learning,
25+ or difficulty-based sampling in the future.
26+
27+ Subclasses must implement:
28+ - get_indices: returns list of indices for next batch
29+ - update: updates internal state using feedback (e.g., loss values, mean rewards, etc.)
30+ - state_dict / load_state_dict: for saving/loading selector state (checkpointing)
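+
+    Typical usage (an illustrative sketch: ``data_source`` is an ``_HFBatchReader``,
+    ``selector_config`` a ``DataSelectorConfig``, and ``compute_rewards`` a hypothetical
+    feedback function, not part of this module)::
+
+        selector = ShuffleSelector(data_source, selector_config)
+        indices = selector.get_indices(batch_size=32)
+        rewards = compute_rewards(indices)   # e.g., mean reward per selected sample
+        selector.update(indices, rewards)    # adaptive selectors use this feedback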
31+ """
32+
1633 def __init__ (self , data_source : _HFBatchReader , config : DataSelectorConfig ):
1734 self .data_source = data_source
1835 self .config = config
1936
2037 def get_indices (self , batch_size : int , return_extra_info : bool = False ) -> List [int ]:
38+ """
39+ Select a batch of sample indices from the dataset.
40+
41+ Args:
42+ batch_size (int): Number of indices to return
43+ return_extra_info (bool): If True, may return additional metadata (future use)
44+
45+ Returns:
46+ List[int]: Selected indices into the dataset
47+ """
2148 raise NotImplementedError

    def update(self, indices: List[int], values: List[float]) -> None:
+        """
+        Update internal state based on feedback (e.g., model loss, accuracy).
+
+        This allows adaptive selectors (like hard example mining) to learn over time.
+
+        Args:
+            indices (List[int]): Previously selected indices
+            values (List[float]): Feedback values corresponding to those indices
+        """
        raise NotImplementedError

    def state_dict(self) -> Dict:
+        """
+        Return serializable state of the selector for checkpointing.
+
+        Returns:
+            Dict: State information (e.g., current position, etc.)
+        """
        raise NotImplementedError

    def load_state_dict(self, state_dict: Dict) -> None:
+        """
+        Restore selector state from a saved dictionary.
+
+        Args:
+            state_dict (Dict): Output from state_dict()
+        """
        raise NotImplementedError


@SELECTORS.register_module("sequential")
class SequentialSelector(BaseSelector):
83+ """
84+ Selects data sequentially in fixed order across epochs.
85+
86+ Example: [0,1,2,...,B-1], then [B,B+1,...,2B-1], etc., wrapping at epoch boundaries.
87+ Useful for deterministic iteration or when combined with external shuffling.
88+ """
89+
3590 def __init__ (self , data_source : _HFBatchReader , config : DataSelectorConfig ):
3691 super ().__init__ (data_source , config )
3792 self .num_per_epoch = data_source .num_per_epoch
@@ -40,11 +95,14 @@ def __init__(self, data_source: _HFBatchReader, config: DataSelectorConfig):
    def get_indices(self, batch_size: int, return_extra_info: bool = False) -> List[int]:
        start = self.current_index % self.num_per_epoch
        end = start + batch_size
-        assert end <= self.num_per_epoch, f"Batch size ({batch_size}) is too large"
+        assert (
+            end <= self.num_per_epoch
+        ), f"Batch size ({batch_size}) exceeds remaining data in epoch"
        self.current_index += batch_size
        return list(range(start, end))

    def update(self, indices: List[int], values: List[float]) -> None:
+        # No-op: sequential selection doesn't adapt based on feedback
        pass

    def state_dict(self) -> Dict:
@@ -58,29 +116,48 @@ def load_state_dict(self, state_dict):

@SELECTORS.register_module("shuffle")
class ShuffleSelector(BaseSelector):
+    """
+    Shuffles the dataset once per epoch and iterates through it sequentially.
+
+    Each epoch uses a different permutation of a subset of the full dataset
+    (of size num_per_epoch). When one epoch ends, a new shuffle is triggered.
+    Mimics a standard PyTorch DataLoader with shuffle=True.
+    """
+
    def __init__(self, data_source: _HFBatchReader, config: DataSelectorConfig):
        super().__init__(data_source, config)
-        self.dataset_size = data_source.dataset_size
-        self.num_per_epoch = data_source.num_per_epoch
-        self.current_index = 0
-        self.seed = config.seed
-        self.order = self._get_order()
+        self.dataset_size = data_source.dataset_size  # Total available samples
+        self.num_per_epoch = data_source.num_per_epoch  # Samples used per epoch
+        self.current_index = 0  # Progress tracker
+        self.seed = config.seed  # For reproducible shuffling
+        self.order = self._get_order()  # Current shuffled index order

    def _get_order(self) -> List[int]:
+        """
+        Generate a new shuffled order for the current epoch.
+
+        Uses NumPy's default PCG64 generator, seeded with the base seed plus the epoch
+        number, so each epoch gets a different shuffle while staying deterministic for a fixed seed.
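+
+        For example, with seed = 42 and num_per_epoch = 1000, the third epoch
+        (current_index in [2000, 3000)) reshuffles with np.random.default_rng(44).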
141+ """
70142 rng = np .random .default_rng (self .seed + self .current_index // self .num_per_epoch )
71143 return rng .choice (self .dataset_size , self .num_per_epoch , replace = False )
72144
73145 def get_indices (self , batch_size : int , return_extra_info : bool = False ) -> List [int ]:
74146 start = self .current_index % self .num_per_epoch
75147 end = start + batch_size
76148 assert end <= self .num_per_epoch , f"Batch size ({ batch_size } ) is too large"
149+
150+ # Fetch pre-shuffled indices for this batch
77151 ret = self .order [start :end ]
78152 self .current_index += batch_size
153+
154+ # At end of epoch, reshuffle for next epoch
79155 if self .current_index % self .num_per_epoch == 0 :
80156 self .order = self ._get_order ()
81157 return ret
82158
83159 def update (self , indices : List [int ], values : List [float ]) -> None :
160+ # No-op: static shuffling does not adapt
84161 pass
85162
86163 def state_dict (self ) -> Dict :
@@ -95,6 +172,13 @@ def load_state_dict(self, state_dict):

@SELECTORS.register_module("random")
class RandomSelector(BaseSelector):
+    """
+    Samples every batch uniformly at random from the full dataset.
+
+    Unlike ShuffleSelector, there is no notion of an epoch: each batch is drawn independently
+    (without replacement within the batch), so samples may repeat across batches.
+    Suitable for online or stochastic training regimes.
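+
+    For example, with a constant batch_size of 32, batch k is drawn with
+    np.random.default_rng(seed + 32 * k).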
180+ """
181+
98182 def __init__ (self , data_source : _HFBatchReader , config : DataSelectorConfig ):
99183 super ().__init__ (data_source , config )
100184 self .dataset_size = data_source .dataset_size
@@ -103,6 +187,7 @@ def __init__(self, data_source: _HFBatchReader, config: DataSelectorConfig):
        self.seed = config.seed

    def get_indices(self, batch_size, return_extra_info=False):
+        # Seed varies per batch to ensure repeatability across runs
        rng = np.random.default_rng(self.seed + self.current_index)
        selected_indices = rng.choice(self.dataset_size, batch_size, replace=False)
        self.current_index += batch_size
@@ -112,6 +197,7 @@ def get_indices(self, batch_size, return_extra_info=False):
        return selected_indices

    def update(self, indices: List[int], values: List[float]) -> None:
+        # No-op: basic random selection doesn't adapt
        pass

    def state_dict(self) -> Dict:
@@ -125,29 +211,59 @@ def load_state_dict(self, state_dict):

@SELECTORS.register_module("offline_easy2hard")
class OfflineEasy2HardSelector(BaseSelector):
+    """
+    Selects samples in an 'easy-to-hard' curriculum based on pre-defined difficulty features.
+
+    This selector assumes that higher feature values indicate easier examples.
+    It sorts all data once at initialization by descending feature value(s), then sequentially
+    serves batches from easy → hard over epochs. The sorting is fixed (offline), so no online
+    adaptation occurs during training.
+
+    Useful for curriculum learning where sample difficulty is estimated ahead of time
+    (e.g., via teacher model confidence, length, BLEU score, etc.).
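+
+    For example, a single feature with values [0.2, 0.9, 0.5] yields sorted_index [1, 2, 0],
+    so sample 1 (the easiest under this assumption) is served first.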
224+ """
225+
128226 def __init__ (self , data_source , config : DataSelectorConfig ):
129227 super ().__init__ (data_source , config )
130228 self .logger = get_logger ("offline_easy2hard_selector" )
131229
230+ # Extract specified feature columns (e.g., 'loss', 'confidence') used to estimate difficulty
132231 feature_keys = config .feature_keys
133232 self .features = np .concatenate (
134233 [np .array (list (data_source .dataset [k ]))[:, None ] for k in feature_keys ], axis = 1
135234 )
235+ # Shape: (N, len(feature_keys)) — one row per sample, one column per feature
236+
237+ # Append index to each feature vector for tracking original positions after sorting
136238 features_with_index = [list (self .features [i ]) + [i ] for i in range (len (self .features ))]
239+
240+ # Sort by feature values in descending order → highest (easiest) first
137241 features_with_index = sorted (features_with_index )[::- 1 ]
138242 self .logger .debug (f"OfflineEasy2HardSelector, sorted { features_with_index [:20 ]} " )
243+
244+ # Store the sorted order of indices (from easiest to hardest)
139245 self .sorted_index = np .array ([i [- 1 ] for i in features_with_index ])
140246
247+ # Number of samples per epoch (may be less than full dataset size)
141248 self .num_per_epoch = data_source .num_per_epoch
142249 self .current_index = 0
143250
144251 def update (self , indices : List [int ], values : List [float ]) -> None :
252+ # No-op: this selector does not adapt based on runtime feedback
145253 pass
146254
147255 def get_indices (self , batch_size , return_extra_info = False ):
256+ """
257+ Returns next batch of indices in curriculum order (easy → hard).
258+
259+ Batches are taken sequentially from the pre-sorted list. When epoch ends,
260+ it wraps around to the beginning (i.e., restarts curriculum).
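+
+        For example, with num_per_epoch = 100 and batch_size = 32, the fourth call would
+        request indices 96-128 and trip the assertion below, so batch_size should divide
+        num_per_epoch evenly.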
261+ """
148262 start = self .current_index % self .num_per_epoch
149263 end = start + batch_size
150- assert end <= self .num_per_epoch , f"Batch size ({ batch_size } ) is too large"
264+ assert (
265+ end <= self .num_per_epoch
266+ ), f"Batch size ({ batch_size } ) exceeds available data in epoch"
151267 self .current_index += batch_size
152268 selected_indices = self .sorted_index [start :end ]
153269 if not return_extra_info :
@@ -161,57 +277,109 @@ def get_indices(self, batch_size, return_extra_info=False):
        return selected_indices, extra_info

    def state_dict(self) -> Dict:
+        """
+        Save the current position in the curriculum for checkpointing.
+        Allows resuming from the same point in the easy-to-hard progression.
+        """
        return {
            "current_index": self.current_index,
        }

    def load_state_dict(self, state_dict):
+        """
+        Restore progress through the curriculum from saved state.
+        """
        self.current_index = state_dict.get("current_index", 0)


@SELECTORS.register_module("diff_based")
class DiffBasedSelector(BaseSelector):
+    """
+    Adaptive difficulty-based selector using probabilistic modeling of sample difficulty.
+
+    Uses `InterpolationBetaPREstimator` to model each sample's probability of success (PR),
+    updated with observed feedback (e.g., loss, accuracy). It then selects samples whose
+    predicted PR is close to a target reward (e.g., 1.0 for perfect performance), a form of
+    *targeted difficulty sampling* that focuses on items near the edge of model capability.
+
+    Supports both greedy selection (`tau=0`) and stochastic sampling (`tau>0`).
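+
+    Illustrative configuration (a sketch; the two feature key names are hypothetical,
+    and the defaults shown are the ones read from ``config.kwargs`` below)::
+
+        config.feature_keys = ["pass_rate", "difficulty"]   # exactly two keys expected
+        config.kwargs = {
+            "do_sample": False, "target_reward": 1.0, "tau": 1.0,
+            "m": 16, "lamb": 0.2, "rho": 0.2, "adaptive_rho": False,
+        }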
306+ """
307+
174308 def __init__ (self , data_source , config : DataSelectorConfig ) -> None :
175309 super ().__init__ (data_source , config )
176310 self .logger = get_logger ("diff_based_selector" )
177- self .diff_estimator = self .build_diff_estimator (data_source .dataset , config )
311+
312+ # Initialize difficulty estimator using two features (assumed: e.g., correctness & uncertainty)
313+ self .diff_estimator = self .build_diff_estimator (
314+ data_source .dataset , config .feature_keys , config .kwargs
315+ )
178316 self .current_index = 0
179317 self .seed = config .seed
180318
181- def build_diff_estimator (self , dataset , config : DataSelectorConfig ):
+        self.do_sample = config.kwargs.get(
+            "do_sample", False
+        )  # Whether to sample PR during estimation
+        self.target_reward = config.kwargs.get("target_reward", 1.0)  # Desired performance level
+        self.tau = config.kwargs.get("tau", 1.0)  # Temperature for sampling distribution
+
+    def build_diff_estimator(self, dataset, feature_keys: List[str], config: dict):
+        """
+        Constructs a Beta-distribution-based difficulty estimator from features.
+
+        Expects exactly two feature keys (e.g., ['correct', 'uncertainty']), which are concatenated
+        into a feature matrix and passed to InterpolationBetaPREstimator for modeling P(success).
+        """
        self.logger.debug(f"{config=}")
-        feature_keys = config.feature_keys
        assert len(feature_keys) == 2
        features = np.concatenate(
            [np.array(list(dataset[k]))[:, None] for k in feature_keys], axis=1
        )
        self.logger.debug(f"{features.shape=}")
        self.logger.debug(f"{features[:5]=}")
-        adaptive_rho = hasattr(config, "adaptive_rho") and config.adaptive_rho
+        adaptive_rho = config.get("adaptive_rho", False)
        return InterpolationBetaPREstimator(
            features=features,
-            m=config.m,
-            lamb=config.lamb,
-            rho=config.rho,
+            m=config.get("m", 16),
+            lamb=config.get("lamb", 0.2),
+            rho=config.get("rho", 0.2),
            adaptive_rho=adaptive_rho,
        )

    def update(self, indices: List[int], values: List[float]) -> None:
349+ """
350+ Updates the difficulty estimator with observed performance on selected samples.
351+
352+ Args:
353+ indices (List[int]): Previously selected sample indices
354+ values (List[float]): Observed rewards/scores (e.g., accuracy, BLEU) for those samples
355+ """
200356 self .diff_estimator .update (indices , values )
201357
202358 def get_scores (self ) -> List [float ]:
359+ """
360+ Computes selection scores: negative distance between predicted PR and target reward.
361+
362+ Samples whose predicted performance is closest to `target_reward` receive highest scores.
363+ Encourages selection of "just right" difficulty samples (neither too easy nor too hard).
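+
+        For example, with target_reward = 0.5 and predicted PR values [0.9, 0.5, 0.1],
+        the scores are [-0.4, 0.0, -0.4], so the middle sample is preferred.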
364+ """
203365 rng = np .random .default_rng (self .seed + self .current_index )
204- predicted_pr = self .diff_estimator .predict_pr (rng = rng , do_sample = self .config . do_sample )
205- scores = - np .abs (self .config . target_reward - predicted_pr )
366+ predicted_pr = self .diff_estimator .predict_pr (rng = rng , do_sample = self .do_sample )
367+ scores = - np .abs (self .target_reward - predicted_pr )
206368 return scores
207369
208370 def get_indices (self , batch_size , return_extra_info = False ):
371+ """
372+ Selects batch of indices based on difficulty proximity to target.
373+
374+ If tau == 0: take top-k highest scoring samples (greedy).
375+ Else: sample stochastically using softmax(logits / tau).
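+
+        For example, scores of [-0.4, 0.0, -0.4] with tau = 0.1 give sampling
+        probabilities of roughly [0.02, 0.96, 0.02].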
376+ """
209377 sampling_scores = self .get_scores ()
210378 sampling_scores = torch .from_numpy (sampling_scores )
211- if self .config . tau == 0 :
379+ if self .tau == 0 :
212380 selected_indices = torch .topk (sampling_scores , batch_size ).indices
213381 else :
214- sampling_logits = sampling_scores / self .config . tau
382+ sampling_logits = sampling_scores / self .tau
215383 sampling_logits -= sampling_logits .max ()
216384 sampling_probabilities = torch .softmax (sampling_logits , dim = 0 )
217385 rng = torch .Generator ()
@@ -244,9 +412,16 @@ def get_indices(self, batch_size, return_extra_info=False):
        return selected_indices

    def state_dict(self) -> Dict:
+        """
+        Save current state for checkpointing.
+        Only tracks sampling progress; actual difficulty estimates are in diff_estimator.
+        """
        return {
            "current_index": self.current_index,
        }

    def load_state_dict(self, state_dict):
+        """
+        Restore selector state from checkpoint.
+        """
        self.current_index = state_dict.get("current_index", 0)