
Commit 41bd02f

fix unittest
1 parent a24ad19

File tree: 8 files changed (+136, -109 lines changed)

tests/buffer/task_scheduler_test.py

Lines changed: 60 additions & 40 deletions

@@ -43,125 +43,145 @@ def _check_batch_tasks(self, batch_tasks: List[Task], indices: List[Dict[str, in
         {"selector_type": "sequential"},
         [
             {"index": 0, "taskset_id": 1},
-            {"index": 1, "taskset_id": 1},
             {"index": 0, "taskset_id": 0},
-            {"index": 2, "taskset_id": 1},
+            {"index": 1, "taskset_id": 1},
             {"index": 1, "taskset_id": 0},
+            {"index": 2, "taskset_id": 1},
             {"index": 3, "taskset_id": 1},
-            {"index": 4, "taskset_id": 1},
             {"index": 2, "taskset_id": 0},
-            {"index": 5, "taskset_id": 1},
             {"index": 3, "taskset_id": 0},
+            {"index": 4, "taskset_id": 1},
+            {"index": 5, "taskset_id": 1},
+            {"index": 6, "taskset_id": 1},
+            {"index": 4, "taskset_id": 0},
             {"index": 0, "taskset_id": 1},
             {"index": 1, "taskset_id": 1},
-            {"index": 2, "taskset_id": 1},
             {"index": 0, "taskset_id": 0},
+            {"index": 2, "taskset_id": 1},
+            {"index": 3, "taskset_id": 1},
             {"index": 1, "taskset_id": 0},
             {"index": 2, "taskset_id": 0},
-            {"index": 3, "taskset_id": 1},
             {"index": 4, "taskset_id": 1},
-            {"index": 5, "taskset_id": 1},
             {"index": 3, "taskset_id": 0},
+            {"index": 5, "taskset_id": 1},
+            {"index": 6, "taskset_id": 1},
+            {"index": 4, "taskset_id": 0},
         ],
     ),
     (
         {"selector_type": "shuffle", "seed": 42},
         [
             {"index": 3, "taskset_id": 1},
-            {"index": 2, "taskset_id": 1},
             {"index": 4, "taskset_id": 0},
-            {"index": 5, "taskset_id": 1},
-            {"index": 0, "taskset_id": 0},
+            {"index": 2, "taskset_id": 1},
+            {"index": 2, "taskset_id": 0},
             {"index": 6, "taskset_id": 1},
             {"index": 4, "taskset_id": 1},
             {"index": 3, "taskset_id": 0},
-            {"index": 0, "taskset_id": 1},
-            {"index": 2, "taskset_id": 0},
+            {"index": 1, "taskset_id": 0},
             {"index": 1, "taskset_id": 1},
-            {"index": 3, "taskset_id": 1},
+            {"index": 5, "taskset_id": 1},
             {"index": 0, "taskset_id": 1},
-            {"index": 2, "taskset_id": 0},
             {"index": 0, "taskset_id": 0},
-            {"index": 1, "taskset_id": 0},
+            {"index": 2, "taskset_id": 1},
             {"index": 6, "taskset_id": 1},
+            {"index": 4, "taskset_id": 0},
             {"index": 5, "taskset_id": 1},
-            {"index": 2, "taskset_id": 1},
+            {"index": 1, "taskset_id": 1},
+            {"index": 1, "taskset_id": 0},
+            {"index": 2, "taskset_id": 0},
+            {"index": 4, "taskset_id": 1},
+            {"index": 0, "taskset_id": 0},
+            {"index": 0, "taskset_id": 1},
+            {"index": 3, "taskset_id": 1},
             {"index": 3, "taskset_id": 0},
         ],
     ),
     (
         {"selector_type": "random", "seed": 42},
         [
             {"index": 0, "taskset_id": 1},
-            {"index": 5, "taskset_id": 1},
             {"index": 0, "taskset_id": 0},
+            {"index": 3, "taskset_id": 1},
+            {"index": 2, "taskset_id": 0},
             {"index": 4, "taskset_id": 1},
+            {"index": 0, "taskset_id": 1},
             {"index": 2, "taskset_id": 0},
+            {"index": 0, "taskset_id": 0},
             {"index": 6, "taskset_id": 1},
             {"index": 3, "taskset_id": 1},
-            {"index": 3, "taskset_id": 0},
-            {"index": 0, "taskset_id": 1},
-            {"index": 4, "taskset_id": 0},
-            {"index": 2, "taskset_id": 1},
             {"index": 0, "taskset_id": 1},
-            {"index": 5, "taskset_id": 1},
             {"index": 2, "taskset_id": 0},
+            {"index": 0, "taskset_id": 1},
+            {"index": 2, "taskset_id": 1},
             {"index": 0, "taskset_id": 0},
-            {"index": 3, "taskset_id": 0},
             {"index": 2, "taskset_id": 1},
             {"index": 6, "taskset_id": 1},
-            {"index": 5, "taskset_id": 1},
             {"index": 0, "taskset_id": 0},
+            {"index": 0, "taskset_id": 0},
+            {"index": 5, "taskset_id": 1},
+            {"index": 3, "taskset_id": 0},
+            {"index": 2, "taskset_id": 1},
+            {"index": 6, "taskset_id": 1},
+            {"index": 1, "taskset_id": 0},
         ],
     ),
     (
         {"selector_type": "offline_easy2hard", "feature_keys": ["feature_offline"]},
         [
             {"index": 3, "taskset_id": 1},
-            {"index": 4, "taskset_id": 1},
             {"index": 3, "taskset_id": 0},
-            {"index": 1, "taskset_id": 1},
+            {"index": 4, "taskset_id": 1},
             {"index": 0, "taskset_id": 0},
+            {"index": 1, "taskset_id": 1},
             {"index": 0, "taskset_id": 1},
-            {"index": 6, "taskset_id": 1},
             {"index": 2, "taskset_id": 0},
-            {"index": 5, "taskset_id": 1},
             {"index": 4, "taskset_id": 0},
+            {"index": 6, "taskset_id": 1},
+            {"index": 5, "taskset_id": 1},
+            {"index": 2, "taskset_id": 1},
+            {"index": 1, "taskset_id": 0},
             {"index": 3, "taskset_id": 1},
             {"index": 4, "taskset_id": 1},
-            {"index": 1, "taskset_id": 1},
             {"index": 3, "taskset_id": 0},
+            {"index": 1, "taskset_id": 1},
+            {"index": 0, "taskset_id": 1},
             {"index": 0, "taskset_id": 0},
             {"index": 2, "taskset_id": 0},
-            {"index": 0, "taskset_id": 1},
             {"index": 6, "taskset_id": 1},
-            {"index": 5, "taskset_id": 1},
             {"index": 4, "taskset_id": 0},
+            {"index": 5, "taskset_id": 1},
+            {"index": 2, "taskset_id": 1},
+            {"index": 1, "taskset_id": 0},
         ],
     ),
     (
         {"selector_type": "diff_based", "feature_keys": ["feat_1", "feat_2"]},
         [
             {"index": 3, "taskset_id": 1},
-            {"index": 0, "taskset_id": 1},
             {"index": 3, "taskset_id": 0},
-            {"index": 2, "taskset_id": 1},
+            {"index": 6, "taskset_id": 1},
             {"index": 2, "taskset_id": 0},
-            {"index": 4, "taskset_id": 1},
             {"index": 2, "taskset_id": 1},
-            {"index": 2, "taskset_id": 0},
-            {"index": 6, "taskset_id": 1},
-            {"index": 4, "taskset_id": 0},
-            {"index": 4, "taskset_id": 1},
             {"index": 3, "taskset_id": 1},
+            {"index": 2, "taskset_id": 0},
+            {"index": 3, "taskset_id": 0},
+            {"index": 2, "taskset_id": 1},
             {"index": 1, "taskset_id": 1},
+            {"index": 4, "taskset_id": 1},
             {"index": 2, "taskset_id": 0},
+            {"index": 3, "taskset_id": 1},
+            {"index": 2, "taskset_id": 1},
             {"index": 4, "taskset_id": 0},
-            {"index": 0, "taskset_id": 0},
             {"index": 4, "taskset_id": 1},
-            {"index": 2, "taskset_id": 1},
             {"index": 5, "taskset_id": 1},
+            {"index": 4, "taskset_id": 0},
             {"index": 3, "taskset_id": 0},
+            {"index": 5, "taskset_id": 1},
+            {"index": 1, "taskset_id": 0},
+            {"index": 6, "taskset_id": 1},
+            {"index": 6, "taskset_id": 1},
+            {"index": 4, "taskset_id": 0},
         ],
     ),
 ]

tests/cli/launcher_test.py

Lines changed: 1 addition & 1 deletion

@@ -263,7 +263,7 @@ def test_debug_mode(self, mock_load):
         except Exception:
             time.sleep(3)
         output_file = os.path.join(self.config.checkpoint_job_dir, "debug.html")
-        self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k")
+        self.config.buffer.explorer_input.tasksets = [get_unittest_dataset_config("gsm8k")]
         mock_load.return_value = self.config
         with mock.patch(
             "argparse.ArgumentParser.parse_args",

tests/common/config_test.py

Lines changed: 1 addition & 1 deletion

@@ -31,7 +31,7 @@ def test_load_default_config(self):
         self.assertEqual(config.trainer.trainer_config.trainer.project_name, config.project)
         self.assertEqual(config.trainer.trainer_config.trainer.experiment_name, config.name)
         self.assertEqual(
-            config.buffer.explorer_input.taskset.repeat_times, config.algorithm.repeat_times
+            config.buffer.explorer_input.tasksets[0].repeat_times, config.algorithm.repeat_times
         )
         self.assertEqual(config.model.model_path, config.model.critic_model_path)
         self.assertEqual(config.model.model_path, config.explorer.rollout_model.model_path)
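
Note: both test updates above follow the same config migration: explorer_input now carries a list-valued tasksets field instead of a single taskset. A minimal self-contained sketch of the new access pattern, using hypothetical stand-in dataclasses rather than the project's real config classes:

    from dataclasses import dataclass, field
    from typing import List

    @dataclass
    class TasksetConfig:              # hypothetical stand-in, not the project's class
        name: str
        repeat_times: int = 1

    @dataclass
    class ExplorerInput:              # note the plural, list-valued field
        tasksets: List[TasksetConfig] = field(default_factory=list)

    explorer_input = ExplorerInput(tasksets=[TasksetConfig("gsm8k", repeat_times=8)])
    # old single-taskset access:  explorer_input.taskset.repeat_times
    # new list access, as in the updated tests:
    assert explorer_input.tasksets[0].repeat_times == 8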

trinity/buffer/reader/file_reader.py

Lines changed: 13 additions & 22 deletions

@@ -45,11 +45,7 @@ def __init__(
         if total_steps:
             self.total_samples = default_batch_size * total_steps
         else:
-            if drop_last:
-                self.num_per_epoch = self.dataset_size - (self.dataset_size % default_batch_size)
-            else:
-                self.num_per_epoch = self.dataset_size
-            self.total_samples = self.num_per_epoch * total_epochs
+            self.total_samples = self.dataset_size * total_epochs

         if enable_progress_bar:
             from ray.experimental.tqdm_ray import tqdm
@@ -68,26 +64,21 @@ def current_seed(self):
         return self.base_seed + self.current_offset // self.dataset_size

     def read_batch(self, batch_size: int) -> Union[List, Iterable]:
-        if self.current_offset >= self.total_samples:
-            self.progress_bar.close()
-            raise StopIteration
-        start_epoch = self.current_offset // self.num_per_epoch
-        start_index = self.current_offset % self.num_per_epoch
-
-        batch = []
-        for i in range(start_index, start_index + batch_size):
-            if i < self.num_per_epoch:
-                batch.append(self.dataset[i])
-            else:
-                assert not self.drop_last
-                break
+        batch, indices = [], []
+        while len(batch) < batch_size:
+            if self.current_offset >= self.total_samples:
+                if not self.drop_last and len(batch) > 0:
+                    break
+                self.progress_bar.close()
+                raise StopIteration
+            index = self.current_offset % self.dataset_size
+            batch.append(self.dataset[index])
+            indices.append(index)

         self.current_offset += len(batch)
         self.progress_bar.update(len(batch))
-        if start_epoch != self.current_offset // self.num_per_epoch:
-            assert self.current_offset % self.num_per_epoch == 0

-        return batch, range(start_index, self.current_offset)
+        return batch, indices

     def select_batch(self, indices: List[int]) -> List:
         batch = []
@@ -99,7 +90,7 @@ def select_batch(self, indices: List[int]) -> List:

 class BaseFileReader(BufferReader):
     def __len__(self):
-        return self.dataset.num_per_epoch
+        return self.dataset.dataset_size

     @property
     def index(self) -> int:
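
Note: the rewritten read_batch drops the num_per_epoch bookkeeping; offsets now wrap modulo the dataset size until total_samples (dataset_size * total_epochs when total_steps is unset) is exhausted, and a short final batch is returned when drop_last is False. A simplified standalone sketch of that wrap-around behaviour, with illustrative names rather than the project's _HFBatchReader API:

    from typing import Any, List, Tuple

    def read_batch_wrapping(dataset: List[Any], batch_size: int, offset: int,
                            total_samples: int, drop_last: bool) -> Tuple[List[Any], List[int], int]:
        """Collect up to batch_size samples, wrapping indices modulo len(dataset)."""
        batch, indices = [], []
        while len(batch) < batch_size:
            if offset >= total_samples:          # sample budget exhausted
                if not drop_last and batch:      # keep the final short batch
                    break
                raise StopIteration              # signal exhaustion, as the reader does
            index = offset % len(dataset)        # wrap across epoch boundaries
            batch.append(dataset[index])
            indices.append(index)
            offset += 1
        return batch, indices, offset

    data = list("abcdefg")                       # dataset_size = 7
    offset, total = 0, 2 * len(data)             # total_epochs = 2 -> 14 samples
    batch, idx, offset = read_batch_wrapping(data, 4, offset, total, drop_last=False)
    print(idx)                                   # [0, 1, 2, 3]
    batch, idx, offset = read_batch_wrapping(data, 4, offset, total, drop_last=False)
    print(idx)                                   # [4, 5, 6, 0]  <- wraps into the next epoch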

trinity/buffer/selector/selector.py

Lines changed: 32 additions & 32 deletions

@@ -83,23 +83,21 @@ class SequentialSelector(BaseSelector):
     """
     Selects data sequentially in fixed order across epochs.

-    Example: [0,1,2,...,B-1], then [B,B+1,...,2B-1], etc., wrapping at epoch boundaries.
-    Useful for deterministic iteration or when combined with external shuffling.
+    Example: [0,1,2,...,B-1], then [B,B+1,...,2B-1], etc.
     """

     def __init__(self, data_source: _HFBatchReader, config: DataSelectorConfig):
         super().__init__(data_source, config)
-        self.num_per_epoch = data_source.num_per_epoch
+        self.dataset_size = data_source.dataset_size
         self.current_index = 0

     def get_indices(self, batch_size: int, return_extra_info: bool = False) -> List[int]:
-        start = self.current_index % self.num_per_epoch
+        start = self.current_index % self.dataset_size
         end = start + batch_size
-        assert (
-            end <= self.num_per_epoch
-        ), f"Batch size ({batch_size}) exceeds remaining data in epoch"
         self.current_index += batch_size
-        return list(range(start, end))
+        if end <= self.dataset_size:
+            return list(range(start, end))
+        return list(range(start, self.dataset_size)) + list(range(0, end - self.dataset_size))

     def update(self, indices: List[int], values: List[float]) -> None:
         # No-op: sequential selection doesn't adapt based on feedback
@@ -119,41 +117,42 @@ class ShuffleSelector(BaseSelector):
     """
     Shuffles dataset once per epoch and iterates through it sequentially.

-    Each epoch uses a different permutation of a subset of the full dataset
-    (of size num_per_epoch). When one epoch ends, a new shuffle is triggered.
+    Each epoch uses a different permutation of a subset of the full dataset.
+    When one epoch ends, a new shuffle is triggered.
     Mimics standard PyTorch DataLoader with shuffle=True.
     """

     def __init__(self, data_source: _HFBatchReader, config: DataSelectorConfig):
         super().__init__(data_source, config)
         self.dataset_size = data_source.dataset_size  # Total available samples
-        self.num_per_epoch = data_source.num_per_epoch  # Samples used per epoch
         self.current_index = 0  # Progress tracker
         self.seed = config.seed  # For reproducible shuffling
-        self.order = self._get_order()  # Current shuffled index order
+        self.orders = self._get_orders()  # Current shuffled index order

-    def _get_order(self) -> List[int]:
+    def _get_orders(self) -> List[int]:
         """
         Generate a new shuffled order for the current epoch.

         Uses NumPy's PCG64 random generator seeded by epoch number for reproducibility.
         Ensures different shuffle per epoch while being deterministic if seed is fixed.
         """
-        rng = np.random.default_rng(self.seed + self.current_index // self.num_per_epoch)
-        return rng.choice(self.dataset_size, self.num_per_epoch, replace=False)
+        rng = np.random.default_rng(self.seed + self.current_index // self.dataset_size)
+        return rng.permutation(self.dataset_size).tolist()

     def get_indices(self, batch_size: int, return_extra_info: bool = False) -> List[int]:
-        start = self.current_index % self.num_per_epoch
+        start = self.current_index % self.dataset_size
         end = start + batch_size
-        assert end <= self.num_per_epoch, f"Batch size ({batch_size}) is too large"
-
-        # Fetch pre-shuffled indices for this batch
-        ret = self.order[start:end]
+        if end <= self.dataset_size:
+            ret = self.orders[start:end]
+            # At end of epoch, reshuffle for next epoch
+            if end == self.dataset_size:
+                self.orders = self._get_orders()
+        else:
+            ret = self.orders[start:]
+            # At end of epoch, reshuffle for next epoch
+            self.orders = self._get_orders()
+            ret += self.orders[: (end - self.dataset_size)]
         self.current_index += batch_size
-
-        # At end of epoch, reshuffle for next epoch
-        if self.current_index % self.num_per_epoch == 0:
-            self.order = self._get_order()
         return ret

     def update(self, indices: List[int], values: List[float]) -> None:
@@ -167,7 +166,7 @@ def state_dict(self) -> Dict:

     def load_state_dict(self, state_dict):
         self.current_index = state_dict.get("current_index", 0)
-        self.order = self._get_order()
+        self.orders = self._get_orders()


 @SELECTORS.register_module("random")
@@ -182,7 +181,6 @@ class RandomSelector(BaseSelector):
     def __init__(self, data_source: _HFBatchReader, config: DataSelectorConfig):
         super().__init__(data_source, config)
         self.dataset_size = data_source.dataset_size
-        self.num_per_epoch = data_source.num_per_epoch
         self.current_index = 0
         self.seed = config.seed

@@ -245,7 +243,7 @@ def __init__(self, data_source, config: DataSelectorConfig):
         self.sorted_index = np.array([i[-1] for i in features_with_index])

         # Number of samples per epoch (may be less than full dataset size)
-        self.num_per_epoch = data_source.num_per_epoch
+        self.dataset_size = data_source.dataset_size
         self.current_index = 0

     def update(self, indices: List[int], values: List[float]) -> None:
@@ -259,13 +257,15 @@ def get_indices(self, batch_size, return_extra_info=False):
         Batches are taken sequentially from the pre-sorted list. When epoch ends,
         it wraps around to the beginning (i.e., restarts curriculum).
         """
-        start = self.current_index % self.num_per_epoch
+        start = self.current_index % self.dataset_size
         end = start + batch_size
-        assert (
-            end <= self.num_per_epoch
-        ), f"Batch size ({batch_size}) exceeds available data in epoch"
+        if end <= self.dataset_size:
+            selected_indices = self.sorted_index[start:end]
+        else:
+            selected_indices = np.concatenate(
+                [self.sorted_index[start:], self.sorted_index[: (end - self.dataset_size)]]
+            )
         self.current_index += batch_size
-        selected_indices = self.sorted_index[start:end]
         if not return_extra_info:
             return selected_indices
         else:
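
Note: the selector changes above share one pattern: indexing is now modulo dataset_size instead of num_per_epoch, and a batch that runs past the end of the current order is spliced with the head of the next epoch's order rather than hitting an assert. A small self-contained sketch of that splice, using a plain list and a reshuffle callback instead of the project's selector classes:

    import numpy as np

    def wrap_slice(order, start, batch_size, reshuffle):
        """Take batch_size indices from `order` starting at `start`, wrapping into a
        freshly generated order (the next epoch) when the end is passed."""
        end = start + batch_size
        n = len(order)
        if end <= n:
            ret = order[start:end]
            if end == n:                  # epoch boundary hit exactly: prepare next order
                order = reshuffle()
        else:
            ret = order[start:]           # tail of the current epoch
            order = reshuffle()           # next epoch's order
            ret += order[: end - n]       # head of the next epoch
        return ret, order

    rng = np.random.default_rng(42)
    dataset_size = 7
    reshuffle = lambda: rng.permutation(dataset_size).tolist()
    order = reshuffle()
    batch, order = wrap_slice(order, start=5, batch_size=4, reshuffle=reshuffle)
    print(batch)                          # 2 indices from this epoch + 2 from the next permutation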
