6 changes: 2 additions & 4 deletions tests/trainer/trainer_test.py
@@ -112,10 +112,8 @@ def test_trainer(self):
self.assertTrue(len(copy_countdown_metrics) > 0)
countdown_metric_steps = parser.metric_steps(countdown_metrics[0])
countdown_copy_metric_steps = parser.metric_steps(copy_countdown_metrics[0])
self.assertEqual(2, len(countdown_metric_steps))
self.assertEqual(2, len(countdown_copy_metric_steps))
self.assertTrue(4 in countdown_metric_steps)
self.assertTrue(8 in countdown_metric_steps)
self.assertEqual([0, 4, 8], countdown_metric_steps)
self.assertEqual([0, 4, 8], countdown_copy_metric_steps)

def tearDown(self):
# remove dir only when the test passed
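The reworked assertions expect eval metrics at steps 0, 4, and 8 rather than only at 4 and 8, consistent with the explorer changes below, which appear to also evaluate the model at step 0 before any explore step. A minimal sketch of the implied schedule, assuming an eval interval of 4 and 8 explore steps as this test suggests (both assumptions, not stated in the diff):

```python
# Minimal sketch (not from the PR): the eval schedule the new assertions imply,
# assuming eval_interval=4 and 8 explore steps.
def expected_eval_steps(total_steps: int, eval_interval: int) -> list:
    # step 0 (the model before any explore step) is now evaluated as well
    return [step for step in range(total_steps + 1) if step % eval_interval == 0]

assert expected_eval_steps(8, 4) == [0, 4, 8]
```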
44 changes: 30 additions & 14 deletions trinity/explorer/explorer.py
@@ -6,7 +6,7 @@
import os
import time
from collections import defaultdict
from typing import List, Optional, Tuple
from typing import List, Optional

import torch

@@ -65,6 +65,7 @@ def __init__(self, config: Config):
self.use_checkpoint_weights_update = (
self.config.synchronizer.sync_method == SyncMethod.CHECKPOINT
)
self.eval_explore_step_num = None

# For checkpoint weights update
# Use explorer to periodically load the latest model weights and
@@ -170,17 +171,26 @@ async def get_weight(self, name: str) -> torch.Tensor:
return self.state_dict[name]

async def explore(self) -> str:
"""
The dreaming loop for explorer and trainer.
| <----------------------------------------- one period ----------------------------------------------> |
explorer | <-- step_1 --> | <-- step_2 --> | ... | <-- step_n --> | <-- eval --> | <-- [idle] --> | <-- sync --> |
trainer | <-- idle --> | <-- step_1 --> | <-- step_2 --> | ... | <-- step_n --> | <-- [idle] --> | <-- sync --> |
"""
self.eval_explore_step_num = None
while True:
try:
if (
self.eval_explore_step_num is None
and self.explore_step_num % self.config.explorer.eval_interval == 0
):
self.eval_explore_step_num = self.explore_step_num
explore_continue = self.explore_step()
if not explore_continue:
break
if self.need_sync():
self.wait_for_workflow_done()
await self.sync_weight()
if self.explore_step_num % self.config.explorer.eval_interval == 0:
self.wait_for_workflow_done()
self.eval()
except Exception as e:
self.logger.error(f"Error in Explorer: {e}")
break
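The loop no longer calls eval() directly when explore_step_num hits the eval interval. Instead it records the step at which an evaluation is due in eval_explore_step_num, and the evaluation itself is deferred to wait_for_workflow_done() (last hunk below), so it runs only after in-flight rollouts have finished and is logged against the step that requested it. A rough sketch of this deferred-eval pattern with simplified names and control flow, not the actual Explorer API:

```python
# Rough sketch of the deferred-eval pattern (simplified; not the real Explorer).
class DeferredEvalLoop:
    def __init__(self, eval_interval: int):
        self.eval_interval = eval_interval
        self.explore_step_num = 0
        self.eval_explore_step_num = None  # step of a pending evaluation, if any

    def explore_step(self) -> None:
        # mark an evaluation as pending, but do not run it yet
        if (
            self.eval_explore_step_num is None
            and self.explore_step_num % self.eval_interval == 0
        ):
            self.eval_explore_step_num = self.explore_step_num
        self.explore_step_num += 1  # stand-in for the real rollout step

    def wait_for_workflow_done(self) -> None:
        # ... wait for in-flight rollouts and collect their metrics ...
        if self.eval_explore_step_num is not None:
            self.eval(self.eval_explore_step_num)
            self.eval_explore_step_num = None  # consume the pending request

    def eval(self, step: int) -> None:
        print(f"evaluating at step {step}")
```

The `is None` guard keeps at most one evaluation pending, so an unserved request is not overwritten by a later step.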
@@ -216,16 +226,18 @@ def need_sync(self) -> bool:
self.explore_step_num - self.config.synchronizer.sync_offset
) % self.config.synchronizer.sync_interval == 0

def eval(self) -> Tuple[bool, int]:
def eval(self, eval_explore_step_num: int):
"""Evaluation on all evaluation data samples."""
if len(self.config.buffer.explorer_input.eval_tasksets) == 0:
self.logger.warning("No evaluation data samples. Skip evaluation.")
return True, self.explore_step_num
self.logger.info("Evaluation started.")
return
self.logger.info(f"Evaluation at step {eval_explore_step_num} started.")
all_st = time.time()
log_metrics = {}
for eval_taskset_config in self.config.buffer.explorer_input.eval_tasksets:
self.logger.info(f"Evaluation on {eval_taskset_config.name} started.")
self.logger.info(
f"Evaluation on {eval_taskset_config.name} at step {eval_explore_step_num} started."
)
eval_taskset = get_buffer_reader(eval_taskset_config, self.config.buffer)
st = time.time()
all_metrics = defaultdict(list)
@@ -254,18 +266,19 @@ def wait():
log_metrics.update(metrics)
log_metrics[f"eval/{eval_taskset.name}/time"] = time.time() - st
log_metrics["eval/total_time"] = time.time() - all_st
self.monitor.log(log_metrics, step=self.explore_step_num) # type: ignore
self.logger.info("Evaluation finished.")
return True, self.explore_step_num
self.monitor.log(log_metrics, step=eval_explore_step_num) # type: ignore
self.logger.info(f"Evaluation at step {eval_explore_step_num} finished.")

async def benchmark(self) -> bool:
"""Benchmark the model checkpoints."""
# benchmark on the latest checkpoint
if self.config.explorer.eval_on_latest_checkpoint:
await self._checkpoint_weights_update()
self.eval()
self.eval(self.explore_step_num)
return True

# benchmark on base model
self.eval(0)
# benchmark on all checkpoints
all_ckp_steps = sorted(
[
@@ -276,9 +289,8 @@ async def benchmark(self) -> bool:
]
)
for step_num in all_ckp_steps:
self.explore_step_num = step_num
await self._checkpoint_weights_update(step_num=step_num)
self.eval()
self.eval(step_num)
return True
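benchmark() now evaluates the base model at step 0 and each checkpoint at its own step by passing the step number to eval() explicitly, instead of overwriting self.explore_step_num while iterating. A minimal sketch of that flow, with hypothetical helper names standing in for the Explorer methods:

```python
# Minimal sketch (hypothetical helpers): benchmark without mutating the
# explorer's own step counter.
def benchmark(eval_at, load_checkpoint, checkpoint_steps) -> bool:
    eval_at(0)  # base model, before any training
    for step in sorted(checkpoint_steps):
        load_checkpoint(step)  # analogous to _checkpoint_weights_update(step_num=step)
        eval_at(step)          # metrics land on this checkpoint's step
    return True
```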

def wait_for_workflow_done(self) -> None:
@@ -302,6 +314,10 @@ def wait_for_workflow_done(self) -> None:
else:
for metric_name, metric_value in status.metric.items():
all_metrics[metric_name].append(metric_value)
# eval
if self.eval_explore_step_num is not None:
self.eval(self.eval_explore_step_num)
self.eval_explore_step_num = None
# calculate metrics
log_metrics = self.monitor.calculate_metrics(all_metrics, prefix="rollout") # type: ignore
self.monitor.log(log_metrics, step=self.explore_step_num)