Skip to content

Commit ba33438

Browse files
authored
Bug fix when set total_steps. (#386)
1 parent af7f8aa commit ba33438

File tree

2 files changed

+70
-11
lines changed

2 files changed

+70
-11
lines changed

tests/buffer/task_scheduler_test.py

Lines changed: 58 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,54 @@ def _check_batch_tasks(self, batch_tasks: List[Task], indices: List[Dict[str, in
4141
@parameterized.expand(
4242
[
4343
(
44+
{"batch_size": 5, "total_steps": 3},
45+
{"selector_type": "sequential"},
46+
[
47+
{"index": 0, "taskset_id": 1},
48+
{"index": 1, "taskset_id": 1},
49+
{"index": 2, "taskset_id": 1},
50+
{"index": 0, "taskset_id": 0},
51+
{"index": 1, "taskset_id": 0},
52+
{"index": 3, "taskset_id": 1},
53+
{"index": 4, "taskset_id": 1},
54+
{"index": 5, "taskset_id": 1},
55+
{"index": 2, "taskset_id": 0},
56+
{"index": 3, "taskset_id": 0},
57+
{"index": 6, "taskset_id": 1},
58+
{"index": 0, "taskset_id": 1},
59+
{"index": 1, "taskset_id": 1},
60+
{"index": 4, "taskset_id": 0},
61+
{"index": 0, "taskset_id": 0},
62+
],
63+
),
64+
(
65+
{"batch_size": 5, "total_epochs": 2},
66+
{"selector_type": "sequential"},
67+
[
68+
{"index": 0, "taskset_id": 1},
69+
{"index": 1, "taskset_id": 1},
70+
{"index": 2, "taskset_id": 1},
71+
{"index": 0, "taskset_id": 0},
72+
{"index": 1, "taskset_id": 0},
73+
{"index": 3, "taskset_id": 1},
74+
{"index": 4, "taskset_id": 1},
75+
{"index": 5, "taskset_id": 1},
76+
{"index": 2, "taskset_id": 0},
77+
{"index": 3, "taskset_id": 0},
78+
{"index": 6, "taskset_id": 1},
79+
{"index": 0, "taskset_id": 1},
80+
{"index": 1, "taskset_id": 1},
81+
{"index": 4, "taskset_id": 0},
82+
{"index": 0, "taskset_id": 0},
83+
{"index": 2, "taskset_id": 1},
84+
{"index": 3, "taskset_id": 1},
85+
{"index": 4, "taskset_id": 1},
86+
{"index": 1, "taskset_id": 0},
87+
{"index": 2, "taskset_id": 0},
88+
],
89+
),
90+
(
91+
{"batch_size": 2, "total_epochs": 2},
4492
{"selector_type": "sequential"},
4593
[
4694
{"index": 0, "taskset_id": 1},
@@ -70,6 +118,7 @@ def _check_batch_tasks(self, batch_tasks: List[Task], indices: List[Dict[str, in
70118
],
71119
),
72120
(
121+
{"batch_size": 2, "total_epochs": 2},
73122
{"selector_type": "shuffle", "seed": 42},
74123
[
75124
{"index": 3, "taskset_id": 1},
@@ -99,6 +148,7 @@ def _check_batch_tasks(self, batch_tasks: List[Task], indices: List[Dict[str, in
99148
],
100149
),
101150
(
151+
{"batch_size": 2, "total_epochs": 2},
102152
{"selector_type": "random", "seed": 42},
103153
[
104154
{"index": 0, "taskset_id": 1},
@@ -128,6 +178,7 @@ def _check_batch_tasks(self, batch_tasks: List[Task], indices: List[Dict[str, in
128178
],
129179
),
130180
(
181+
{"batch_size": 2, "total_epochs": 2},
131182
{"selector_type": "offline_easy2hard", "feature_keys": ["feature_offline"]},
132183
[
133184
{"index": 3, "taskset_id": 1},
@@ -157,6 +208,7 @@ def _check_batch_tasks(self, batch_tasks: List[Task], indices: List[Dict[str, in
157208
],
158209
),
159210
(
211+
{"batch_size": 2, "total_epochs": 2},
160212
{"selector_type": "difficulty_based", "feature_keys": ["feat_1", "feat_2"]},
161213
[
162214
{"index": 3, "taskset_id": 1},
@@ -187,10 +239,13 @@ def _check_batch_tasks(self, batch_tasks: List[Task], indices: List[Dict[str, in
187239
),
188240
]
189241
)
190-
async def test_task_scheduler(self, task_selector_kwargs, batch_tasks_orders) -> None:
242+
async def test_task_scheduler(
243+
self, buffer_config_kwargs, task_selector_kwargs, batch_tasks_orders
244+
) -> None:
191245
config = get_template_config()
192-
config.buffer.batch_size = 2
193-
config.buffer.total_epochs = 2
246+
config.mode = "explore"
247+
for key, value in buffer_config_kwargs.items():
248+
setattr(config.buffer, key, value)
194249
config.buffer.explorer_input.taskset = None
195250
config.buffer.explorer_input.tasksets = [
196251
TasksetConfig(

trinity/buffer/task_scheduler.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,13 @@ def __init__(self, explorer_state: Dict, config: Config):
9090
self.epoch = self.step * self.read_batch_size // len(self.base_taskset_ids)
9191
self.orders = self.build_orders(self.epoch)
9292

93+
if self.config.buffer.total_steps:
94+
self.max_steps = self.config.buffer.total_steps
95+
else:
96+
self.max_steps = (
97+
self.config.buffer.total_epochs * len(self.base_taskset_ids) // self.read_batch_size
98+
)
99+
93100
def build_orders(self, epoch: int):
94101
"""
95102
Creates a shuffled sequence of taskset IDs to control sampling priority per step.
@@ -108,6 +115,9 @@ def build_orders(self, epoch: int):
108115
rng.shuffle(taskset_ids)
109116
return taskset_ids
110117

118+
def _should_stop(self) -> bool:
119+
return self.step >= self.max_steps
120+
111121
async def read_async(self) -> List:
112122
"""
113123
Asynchronously reads a batch of tasks according to the current schedule.
@@ -125,12 +135,8 @@ async def read_async(self) -> List:
125135
Returns:
126136
List[Task]: A batch of tasks from potentially multiple tasksets
127137
"""
128-
if self.config.buffer.total_steps:
129-
if self.step >= self.config.buffer.total_steps:
130-
raise StopAsyncIteration
131-
else:
132-
if self.epoch >= self.config.buffer.total_epochs:
133-
raise StopAsyncIteration
138+
if self._should_stop():
139+
raise StopAsyncIteration
134140

135141
batch_size = self.read_batch_size
136142
start = self.step * batch_size % len(self.base_taskset_ids)
@@ -143,8 +149,6 @@ async def read_async(self) -> List:
143149
else:
144150
taskset_ids = self.orders[start:]
145151
self.epoch += 1
146-
if self.epoch >= self.config.buffer.total_epochs:
147-
raise StopAsyncIteration
148152
self.orders = self.build_orders(self.epoch)
149153
taskset_ids += self.orders[: (end - len(self.base_taskset_ids))]
150154

0 commit comments

Comments
 (0)