Commit 81330f8

Refactor explorer loop (#101)
1 parent 5be7436 commit 81330f8

File tree

2 files changed (+32 −18):
tests/trainer/trainer_test.py
trinity/explorer/explorer.py


tests/trainer/trainer_test.py (2 additions, 4 deletions)

@@ -112,10 +112,8 @@ def test_trainer(self):
         self.assertTrue(len(copy_countdown_metrics) > 0)
         countdown_metric_steps = parser.metric_steps(countdown_metrics[0])
         countdown_copy_metric_steps = parser.metric_steps(copy_countdown_metrics[0])
-        self.assertEqual(2, len(countdown_metric_steps))
-        self.assertEqual(2, len(countdown_copy_metric_steps))
-        self.assertTrue(4 in countdown_metric_steps)
-        self.assertTrue(8 in countdown_metric_steps)
+        self.assertEqual([0, 4, 8], countdown_metric_steps)
+        self.assertEqual([0, 4, 8], countdown_copy_metric_steps)
 
     def tearDown(self):
         # remove dir only when the test passed
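
The expected list changes from presence checks on steps 4 and 8 to an exact [0, 4, 8]: with the refactored loop, an eval is also scheduled at step 0, before the first explore step. A minimal sketch of the schedule the test now asserts, assuming the test config uses an eval interval of 4 over 8 explore steps (values inferred from the expected list, not shown in this diff):

    # Hypothetical reconstruction of the eval schedule the test now asserts.
    eval_interval = 4  # assumed from the expected steps [0, 4, 8]
    total_steps = 8    # assumed from the expected steps [0, 4, 8]

    # Step 0 (base model) plus every eval_interval-th step thereafter.
    expected_eval_steps = [s for s in range(total_steps + 1) if s % eval_interval == 0]
    assert expected_eval_steps == [0, 4, 8]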

trinity/explorer/explorer.py (30 additions, 14 deletions)

@@ -6,7 +6,7 @@
 import os
 import time
 from collections import defaultdict
-from typing import List, Optional, Tuple
+from typing import List, Optional
 
 import torch
 
@@ -65,6 +65,7 @@ def __init__(self, config: Config):
         self.use_checkpoint_weights_update = (
             self.config.synchronizer.sync_method == SyncMethod.CHECKPOINT
         )
+        self.eval_explore_step_num = None
 
         # For checkpoint weights update
         # Use explorer to periodically load the latest model weights and
@@ -170,17 +171,26 @@ async def get_weight(self, name: str) -> torch.Tensor:
         return self.state_dict[name]
 
     async def explore(self) -> str:
+        """
+        The dreaming loop for explorer and trainer.
+        | <----------------------------------------- one period ----------------------------------------------> |
+        explorer | <-- step_1 --> | <-- step_2 --> | ... | <-- step_n --> | <-- eval --> | <-- [idle] --> | <-- sync --> |
+        trainer  | <-- idle --> | <-- step_1 --> | <-- step_2 --> | ... | <-- step_n --> | <-- [idle] --> | <-- sync --> |
+        """
+        self.eval_explore_step_num = None
         while True:
             try:
+                if (
+                    self.eval_explore_step_num is None
+                    and self.explore_step_num % self.config.explorer.eval_interval == 0
+                ):
+                    self.eval_explore_step_num = self.explore_step_num
                 explore_continue = self.explore_step()
                 if not explore_continue:
                     break
                 if self.need_sync():
                     self.wait_for_workflow_done()
                     await self.sync_weight()
-                if self.explore_step_num % self.config.explorer.eval_interval == 0:
-                    self.wait_for_workflow_done()
-                    self.eval()
             except Exception as e:
                 self.logger.error(f"Error in Explorer: {e}")
                 break
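
Taken together, the loop now defers evaluation instead of blocking on it: an eval is marked as pending via eval_explore_step_num inside the loop and executed later in wait_for_workflow_done, once in-flight workflows have drained. A self-contained sketch of this deferred-eval pattern, using simplified hypothetical names rather than the real Explorer interface (and, unlike the real loop, draining on every iteration instead of only when syncing):

    # A minimal, self-contained sketch of the deferred-eval pattern
    # (hypothetical names; not the real Explorer class).
    class LoopSketch:
        def __init__(self, eval_interval: int, total_steps: int):
            self.eval_interval = eval_interval
            self.total_steps = total_steps
            self.step = 0
            self.pending_eval_step = None  # plays the role of eval_explore_step_num

        def explore_step(self) -> bool:
            self.step += 1
            return self.step <= self.total_steps

        def wait_for_workflow_done(self) -> None:
            # The deferred eval fires here, tagged with the step it was scheduled at.
            if self.pending_eval_step is not None:
                print(f"eval at step {self.pending_eval_step}")
                self.pending_eval_step = None

        def run(self) -> None:
            while True:
                # Schedule an eval, but do not run it yet.
                if self.pending_eval_step is None and self.step % self.eval_interval == 0:
                    self.pending_eval_step = self.step
                if not self.explore_step():
                    break
                self.wait_for_workflow_done()
            self.wait_for_workflow_done()  # drain any eval scheduled on the last step

    LoopSketch(eval_interval=4, total_steps=8).run()  # prints evals at steps 0, 4, 8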
@@ -216,16 +226,18 @@ def need_sync(self) -> bool:
             self.explore_step_num - self.config.synchronizer.sync_offset
         ) % self.config.synchronizer.sync_interval == 0
 
-    def eval(self) -> Tuple[bool, int]:
+    def eval(self, eval_explore_step_num: int):
         """Evaluation on all evaluation data samples."""
         if len(self.config.buffer.explorer_input.eval_tasksets) == 0:
             self.logger.warning("No evaluation data samples. Skip evaluation.")
-            return True, self.explore_step_num
-        self.logger.info("Evaluation started.")
+            return
+        self.logger.info(f"Evaluation at step {eval_explore_step_num} started.")
         all_st = time.time()
         log_metrics = {}
         for eval_taskset_config in self.config.buffer.explorer_input.eval_tasksets:
-            self.logger.info(f"Evaluation on {eval_taskset_config.name} started.")
+            self.logger.info(
+                f"Evaluation on {eval_taskset_config.name} at step {eval_explore_step_num} started."
+            )
             eval_taskset = get_buffer_reader(eval_taskset_config, self.config.buffer)
             st = time.time()
             all_metrics = defaultdict(list)
@@ -254,18 +266,19 @@ def wait():
             log_metrics.update(metrics)
             log_metrics[f"eval/{eval_taskset.name}/time"] = time.time() - st
         log_metrics["eval/total_time"] = time.time() - all_st
-        self.monitor.log(log_metrics, step=self.explore_step_num)  # type: ignore
-        self.logger.info("Evaluation finished.")
-        return True, self.explore_step_num
+        self.monitor.log(log_metrics, step=eval_explore_step_num)  # type: ignore
+        self.logger.info(f"Evaluation at step {eval_explore_step_num} finished.")
 
     async def benchmark(self) -> bool:
         """Benchmark the model checkpoints."""
         # benchmark on the latest checkpoint
         if self.config.explorer.eval_on_latest_checkpoint:
             await self._checkpoint_weights_update()
-            self.eval()
+            self.eval(self.explore_step_num)
             return True
 
+        # benchmark on base model
+        self.eval(0)
         # benchmark on all checkpoints
         all_ckp_steps = sorted(
             [
276289
]
277290
)
278291
for step_num in all_ckp_steps:
279-
self.explore_step_num = step_num
280292
await self._checkpoint_weights_update(step_num=step_num)
281-
self.eval()
293+
self.eval(step_num)
282294
return True
283295

284296
def wait_for_workflow_done(self) -> None:
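
For benchmarking, the old code reused the parameterless eval() by mutating self.explore_step_num; the new code passes the step directly and additionally evaluates the base model at step 0. A sketch of the resulting ordering, with hypothetical helper names:

    # Benchmark ordering after this change (hypothetical helpers): base model
    # at step 0, then every saved checkpoint at its own step, without touching
    # the explorer's step counter.
    def benchmark(checkpoint_steps, load_checkpoint, eval_fn) -> bool:
        eval_fn(0)  # base model
        for step_num in sorted(checkpoint_steps):
            load_checkpoint(step_num)
            eval_fn(step_num)
        return True

    benchmark(
        checkpoint_steps=[8, 4],
        load_checkpoint=lambda s: print(f"load checkpoint at step {s}"),
        eval_fn=lambda s: print(f"eval at step {s}"),
    )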
@@ -302,6 +314,10 @@ def wait_for_workflow_done(self) -> None:
             else:
                 for metric_name, metric_value in status.metric.items():
                     all_metrics[metric_name].append(metric_value)
+        # eval
+        if self.eval_explore_step_num is not None:
+            self.eval(self.eval_explore_step_num)
+            self.eval_explore_step_num = None
         # calculate metrics
         log_metrics = self.monitor.calculate_metrics(all_metrics, prefix="rollout")  # type: ignore
         self.monitor.log(log_metrics, step=self.explore_step_num)

0 commit comments
