Skip to content

Commit 7e88eff

Browse files
committed
add doc str
1 parent 607f1db commit 7e88eff

File tree

4 files changed

+293
-38
lines changed

4 files changed

+293
-38
lines changed

trinity/buffer/selector/selector.py

Lines changed: 193 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
"""Data selectors."""
12
from typing import Dict, List
23

34
import numpy as np
@@ -6,32 +7,86 @@
67
from trinity.buffer.reader.file_reader import _HFBatchReader
78
from trinity.buffer.selector.diff_estimator import InterpolationBetaPREstimator
89
from trinity.common.config import DataSelectorConfig
10+
from trinity.utils.annotations import Experimental
911
from trinity.utils.log import get_logger
1012
from trinity.utils.registry import Registry
1113

1214
SELECTORS = Registry("selectors")
1315

1416

17+
@Experimental
1518
class BaseSelector:
19+
"""
20+
Abstract base class defining the interface for custom data selection strategies.
21+
22+
A selector determines which samples (by index) are selected from the dataset
23+
during training. It enables flexible sampling beyond simple
24+
sequential or random access, supporting active learning, curriculum learning,
25+
or difficulty-based sampling in the future.
26+
27+
Subclasses must implement:
28+
- get_indices: returns list of indices for next batch
29+
- update: updates internal state using feedback (e.g., loss values, mean rewards, etc.)
30+
- state_dict / load_state_dict: for saving/loading selector state (checkpointing)
31+
"""
32+
1633
def __init__(self, data_source: _HFBatchReader, config: DataSelectorConfig):
1734
self.data_source = data_source
1835
self.config = config
1936

2037
def get_indices(self, batch_size: int, return_extra_info: bool = False) -> List[int]:
38+
"""
39+
Select a batch of sample indices from the dataset.
40+
41+
Args:
42+
batch_size (int): Number of indices to return
43+
return_extra_info (bool): If True, a subclass may return an (indices, extra_info) tuple instead of the bare indices
44+
45+
Returns:
46+
List[int]: Selected indices into the dataset
47+
"""
2148
raise NotImplementedError
2249

2350
def update(self, indices: List[int], values: List[float]) -> None:
51+
"""
52+
Update internal state based on feedback (e.g., model loss, accuracy).
53+
54+
This allows adaptive selectors (like hard example mining) to learn over time.
55+
56+
Args:
57+
indices (List[int]): Previously selected indices
58+
values (List[float]): Feedback values corresponding to those indices
59+
"""
2460
raise NotImplementedError
2561

2662
def state_dict(self) -> Dict:
63+
"""
64+
Return serializable state of the selector for checkpointing.
65+
66+
Returns:
67+
Dict: State information (e.g., current position, etc.)
68+
"""
2769
raise NotImplementedError
2870

2971
def load_state_dict(self, state_dict: Dict) -> None:
72+
"""
73+
Restore selector state from a saved dictionary.
74+
75+
Args:
76+
state_dict (Dict): Output from state_dict()
77+
"""
3078
raise NotImplementedError
3179

3280

3381
@SELECTORS.register_module("sequential")
3482
class SequentialSelector(BaseSelector):
83+
"""
84+
Selects data sequentially in fixed order across epochs.
85+
86+
Example: [0,1,2,...,B-1], then [B,B+1,...,2B-1], etc., wrapping at epoch boundaries.
87+
Useful for deterministic iteration or when combined with external shuffling.
88+
"""
89+
3590
def __init__(self, data_source: _HFBatchReader, config: DataSelectorConfig):
3691
super().__init__(data_source, config)
3792
self.num_per_epoch = data_source.num_per_epoch
@@ -40,11 +95,14 @@ def __init__(self, data_source: _HFBatchReader, config: DataSelectorConfig):
4095
def get_indices(self, batch_size: int, return_extra_info: bool = False) -> List[int]:
4196
start = self.current_index % self.num_per_epoch
4297
end = start + batch_size
43-
assert end <= self.num_per_epoch, f"Batch size ({batch_size}) is too large"
98+
assert (
99+
end <= self.num_per_epoch
100+
), f"Batch size ({batch_size}) exceeds remaining data in epoch"
44101
self.current_index += batch_size
45102
return list(range(start, end))
46103

47104
def update(self, indices: List[int], values: List[float]) -> None:
105+
# No-op: sequential selection doesn't adapt based on feedback
48106
pass
49107

50108
def state_dict(self) -> Dict:
@@ -58,29 +116,48 @@ def load_state_dict(self, state_dict):
58116

59117
@SELECTORS.register_module("shuffle")
60118
class ShuffleSelector(BaseSelector):
119+
"""
120+
Shuffles dataset once per epoch and iterates through it sequentially.
121+
122+
Each epoch uses a different permutation of a subset of the full dataset
123+
(of size num_per_epoch). When one epoch ends, a new shuffle is triggered.
124+
Mimics standard PyTorch DataLoader with shuffle=True.
125+
"""
126+
61127
def __init__(self, data_source: _HFBatchReader, config: DataSelectorConfig):
62128
super().__init__(data_source, config)
63-
self.dataset_size = data_source.dataset_size
64-
self.num_per_epoch = data_source.num_per_epoch
65-
self.current_index = 0
66-
self.seed = config.seed
67-
self.order = self._get_order()
129+
self.dataset_size = data_source.dataset_size # Total available samples
130+
self.num_per_epoch = data_source.num_per_epoch # Samples used per epoch
131+
self.current_index = 0 # Progress tracker
132+
self.seed = config.seed # For reproducible shuffling
133+
self.order = self._get_order() # Current shuffled index order
68134

69135
def _get_order(self) -> List[int]:
136+
"""
137+
Generate a new shuffled order for the current epoch.
138+
139+
Uses NumPy's default generator (PCG64) seeded with `seed + epoch number` for reproducibility.
140+
Ensures different shuffle per epoch while being deterministic if seed is fixed.
141+
"""
70142
rng = np.random.default_rng(self.seed + self.current_index // self.num_per_epoch)
71143
return rng.choice(self.dataset_size, self.num_per_epoch, replace=False)
72144

73145
def get_indices(self, batch_size: int, return_extra_info: bool = False) -> List[int]:
74146
start = self.current_index % self.num_per_epoch
75147
end = start + batch_size
76148
assert end <= self.num_per_epoch, f"Batch size ({batch_size}) is too large"
149+
150+
# Fetch pre-shuffled indices for this batch
77151
ret = self.order[start:end]
78152
self.current_index += batch_size
153+
154+
# At end of epoch, reshuffle for next epoch
79155
if self.current_index % self.num_per_epoch == 0:
80156
self.order = self._get_order()
81157
return ret
82158

83159
def update(self, indices: List[int], values: List[float]) -> None:
160+
# No-op: static shuffling does not adapt
84161
pass
85162

86163
def state_dict(self) -> Dict:
@@ -95,6 +172,13 @@ def load_state_dict(self, state_dict):
95172

96173
@SELECTORS.register_module("random")
97174
class RandomSelector(BaseSelector):
175+
"""
176+
Uniformly samples each batch at random, without replacement *within a batch*.
177+
178+
Unlike ShuffleSelector, there is no concept of an epoch — every batch is independently sampled.
179+
Because batches are sampled independently, the same sample can recur across batches. Suitable for online or stochastic training regimes.
180+
"""
181+
98182
def __init__(self, data_source: _HFBatchReader, config: DataSelectorConfig):
99183
super().__init__(data_source, config)
100184
self.dataset_size = data_source.dataset_size
@@ -103,6 +187,7 @@ def __init__(self, data_source: _HFBatchReader, config: DataSelectorConfig):
103187
self.seed = config.seed
104188

105189
def get_indices(self, batch_size, return_extra_info=False):
190+
# Seed is offset by current_index so each batch gets a different draw that is still reproducible across runs
106191
rng = np.random.default_rng(self.seed + self.current_index)
107192
selected_indices = rng.choice(self.dataset_size, batch_size, replace=False)
108193
self.current_index += batch_size
@@ -112,6 +197,7 @@ def get_indices(self, batch_size, return_extra_info=False):
112197
return selected_indices
113198

114199
def update(self, indices: List[int], values: List[float]) -> None:
200+
# No-op: basic random selection doesn't adapt
115201
pass
116202

117203
def state_dict(self) -> Dict:
@@ -125,29 +211,59 @@ def load_state_dict(self, state_dict):
125211

126212
@SELECTORS.register_module("offline_easy2hard")
127213
class OfflineEasy2HardSelector(BaseSelector):
214+
"""
215+
Selects samples in an 'easy-to-hard' curriculum based on pre-defined difficulty features.
216+
217+
This selector assumes that higher feature values indicate easier examples.
218+
It sorts all data once at initialization by descending feature value(s), then sequentially
219+
serves batches from easy → hard over epochs. The sorting is fixed (offline), so no online
220+
adaptation occurs during training.
221+
222+
Useful for curriculum learning where sample difficulty is estimated ahead of time
223+
(e.g., via teacher model confidence, length, BLEU score, etc.).
224+
"""
225+
128226
def __init__(self, data_source, config: DataSelectorConfig):
129227
super().__init__(data_source, config)
130228
self.logger = get_logger("offline_easy2hard_selector")
131229

230+
# Extract specified feature columns (e.g., 'loss', 'confidence') used to estimate difficulty
132231
feature_keys = config.feature_keys
133232
self.features = np.concatenate(
134233
[np.array(list(data_source.dataset[k]))[:, None] for k in feature_keys], axis=1
135234
)
235+
# Shape: (N, len(feature_keys)) — one row per sample, one column per feature
236+
237+
# Append index to each feature vector for tracking original positions after sorting
136238
features_with_index = [list(self.features[i]) + [i] for i in range(len(self.features))]
239+
240+
# Sort by feature values in descending order → highest (easiest) first
137241
features_with_index = sorted(features_with_index)[::-1]
138242
self.logger.debug(f"OfflineEasy2HardSelector, sorted {features_with_index[:20]}")
243+
244+
# Store the sorted order of indices (from easiest to hardest)
139245
self.sorted_index = np.array([i[-1] for i in features_with_index])
140246

247+
# Number of samples per epoch (may be less than full dataset size)
141248
self.num_per_epoch = data_source.num_per_epoch
142249
self.current_index = 0
143250

144251
def update(self, indices: List[int], values: List[float]) -> None:
252+
# No-op: this selector does not adapt based on runtime feedback
145253
pass
146254

147255
def get_indices(self, batch_size, return_extra_info=False):
256+
"""
257+
Returns next batch of indices in curriculum order (easy → hard).
258+
259+
Batches are taken sequentially from the pre-sorted list. When epoch ends,
260+
it wraps around to the beginning (i.e., restarts curriculum).
261+
"""
148262
start = self.current_index % self.num_per_epoch
149263
end = start + batch_size
150-
assert end <= self.num_per_epoch, f"Batch size ({batch_size}) is too large"
264+
assert (
265+
end <= self.num_per_epoch
266+
), f"Batch size ({batch_size}) exceeds available data in epoch"
151267
self.current_index += batch_size
152268
selected_indices = self.sorted_index[start:end]
153269
if not return_extra_info:
@@ -161,57 +277,109 @@ def get_indices(self, batch_size, return_extra_info=False):
161277
return selected_indices, extra_info
162278

163279
def state_dict(self) -> Dict:
280+
"""
281+
Save current position in the curriculum for checkpointing.
282+
Allows resuming from same point in the easy→hard progression.
283+
"""
164284
return {
165285
"current_index": self.current_index,
166286
}
167287

168288
def load_state_dict(self, state_dict):
289+
"""
290+
Restore progress through the curriculum from saved state.
291+
"""
169292
self.current_index = state_dict.get("current_index", 0)
170293

171294

172295
@SELECTORS.register_module("diff_based")
173296
class DiffBasedSelector(BaseSelector):
297+
"""
298+
Adaptive difficulty-based selector using probabilistic modeling of sample difficulty.
299+
300+
Uses `InterpolationBetaPREstimator` to model each sample's probability of success (PR),
301+
updated with observed feedback (e.g., loss, accuracy). Then selects samples close to
302+
a target reward (e.g., 1.0 for perfect performance), implementing a form of
303+
*targeted difficulty sampling* — focusing on items near the edge of model capability.
304+
305+
Supports both greedy selection (`tau=0`) and stochastic sampling (`tau>0`).
306+
"""
307+
174308
def __init__(self, data_source, config: DataSelectorConfig) -> None:
175309
super().__init__(data_source, config)
176310
self.logger = get_logger("diff_based_selector")
177-
self.diff_estimator = self.build_diff_estimator(data_source.dataset, config)
311+
312+
# Initialize difficulty estimator using two features (assumed: e.g., correctness & uncertainty)
313+
self.diff_estimator = self.build_diff_estimator(
314+
data_source.dataset, config.feature_keys, config.kwargs
315+
)
178316
self.current_index = 0
179317
self.seed = config.seed
180318

181-
def build_diff_estimator(self, dataset, config: DataSelectorConfig):
319+
self.do_sample = config.kwargs.get(
320+
"do_sample", False
321+
) # Whether to sample PR during estimation
322+
self.target_reward = config.kwargs.get("target_reward", 1.0) # Desired performance level
323+
self.tau = config.kwargs.get("tau", 1.0) # Temperature for sampling distribution
324+
325+
def build_diff_estimator(self, dataset, feature_keys: List[str], config: dict):
326+
"""
327+
Constructs a Beta-distribution-based difficulty estimator from features.
328+
329+
Expects exactly two feature keys (e.g., ['correct', 'uncertainty']), which are concatenated
330+
into a feature matrix and passed to InterpolationBetaPREstimator for modeling P(success).
331+
"""
182332
self.logger.debug(f"{config=}")
183-
feature_keys = config.feature_keys
184333
assert len(feature_keys) == 2
185334
features = np.concatenate(
186335
[np.array(list(dataset[k]))[:, None] for k in feature_keys], axis=1
187336
)
188337
self.logger.debug(f"{features.shape=}")
189338
self.logger.debug(f"{features[:5]=}")
190-
adaptive_rho = hasattr(config, "adaptive_rho") and config.adaptive_rho
339+
adaptive_rho = config.get("adaptive_rho", False)
191340
return InterpolationBetaPREstimator(
192341
features=features,
193-
m=config.m,
194-
lamb=config.lamb,
195-
rho=config.rho,
342+
m=config.get("m", 16),
343+
lamb=config.get("lamb", 0.2),
344+
rho=config.get("rho", 0.2),
196345
adaptive_rho=adaptive_rho,
197346
)
198347

199348
def update(self, indices: List[int], values: List[float]) -> None:
349+
"""
350+
Updates the difficulty estimator with observed performance on selected samples.
351+
352+
Args:
353+
indices (List[int]): Previously selected sample indices
354+
values (List[float]): Observed rewards/scores (e.g., accuracy, BLEU) for those samples
355+
"""
200356
self.diff_estimator.update(indices, values)
201357

202358
def get_scores(self) -> List[float]:
359+
"""
360+
Computes selection scores: negative distance between predicted PR and target reward.
361+
362+
Samples whose predicted performance is closest to `target_reward` receive highest scores.
363+
Encourages selection of "just right" difficulty samples (neither too easy nor too hard).
364+
"""
203365
rng = np.random.default_rng(self.seed + self.current_index)
204-
predicted_pr = self.diff_estimator.predict_pr(rng=rng, do_sample=self.config.do_sample)
205-
scores = -np.abs(self.config.target_reward - predicted_pr)
366+
predicted_pr = self.diff_estimator.predict_pr(rng=rng, do_sample=self.do_sample)
367+
scores = -np.abs(self.target_reward - predicted_pr)
206368
return scores
207369

208370
def get_indices(self, batch_size, return_extra_info=False):
371+
"""
372+
Selects batch of indices based on difficulty proximity to target.
373+
374+
If tau == 0: take top-k highest scoring samples (greedy).
375+
Else: sample stochastically using softmax(scores / tau).
376+
"""
209377
sampling_scores = self.get_scores()
210378
sampling_scores = torch.from_numpy(sampling_scores)
211-
if self.config.tau == 0:
379+
if self.tau == 0:
212380
selected_indices = torch.topk(sampling_scores, batch_size).indices
213381
else:
214-
sampling_logits = sampling_scores / self.config.tau
382+
sampling_logits = sampling_scores / self.tau
215383
sampling_logits -= sampling_logits.max()
216384
sampling_probabilities = torch.softmax(sampling_logits, dim=0)
217385
rng = torch.Generator()
@@ -244,9 +412,16 @@ def get_indices(self, batch_size, return_extra_info=False):
244412
return selected_indices
245413

246414
def state_dict(self) -> Dict:
415+
"""
416+
Save current state for checkpointing.
417+
Only saves the sampling position (current_index); the difficulty estimates held by diff_estimator are not included in this dict.
418+
"""
247419
return {
248420
"current_index": self.current_index,
249421
}
250422

251423
def load_state_dict(self, state_dict):
424+
"""
425+
Restore selector state from checkpoint.
426+
"""
252427
self.current_index = state_dict.get("current_index", 0)

0 commit comments

Comments
 (0)