diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py
index 250ea3eb40..f25ee62fb3 100644
--- a/tests/trainer/trainer_test.py
+++ b/tests/trainer/trainer_test.py
@@ -112,10 +112,8 @@ def test_trainer(self):
         self.assertTrue(len(copy_countdown_metrics) > 0)
         countdown_metric_steps = parser.metric_steps(countdown_metrics[0])
         countdown_copy_metric_steps = parser.metric_steps(copy_countdown_metrics[0])
-        self.assertEqual(2, len(countdown_metric_steps))
-        self.assertEqual(2, len(countdown_copy_metric_steps))
-        self.assertTrue(4 in countdown_metric_steps)
-        self.assertTrue(8 in countdown_metric_steps)
+        self.assertEqual([0, 4, 8], countdown_metric_steps)
+        self.assertEqual([0, 4, 8], countdown_copy_metric_steps)

     def tearDown(self):
         # remove dir only when the test passed
diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py
index 76869154ed..487801bf2a 100644
--- a/trinity/explorer/explorer.py
+++ b/trinity/explorer/explorer.py
@@ -6,7 +6,7 @@
 import os
 import time
 from collections import defaultdict
-from typing import List, Optional, Tuple
+from typing import List, Optional

 import torch

@@ -65,6 +65,7 @@ def __init__(self, config: Config):
         self.use_checkpoint_weights_update = (
             self.config.synchronizer.sync_method == SyncMethod.CHECKPOINT
         )
+        self.eval_explore_step_num = None

         # For checkpoint weights update
         # Use explorer to periodically load the latest model weights and
@@ -170,17 +171,26 @@ async def get_weight(self, name: str) -> torch.Tensor:
         return self.state_dict[name]

     async def explore(self) -> str:
+        """
+        The dreaming loop for explorer and trainer.
+        | <----------------------------------------- one period ----------------------------------------------> |
+        explorer | <-- step_1 --> | <-- step_2 --> | ... | <-- step_n --> | <-- eval --> | <-- [idle] --> | <-- sync --> |
+        trainer  | <--  idle  --> | <-- step_1 --> | <-- step_2 --> | ... | <-- step_n --> | <-- [idle] --> | <-- sync --> |
+        """
+        self.eval_explore_step_num = None
         while True:
             try:
+                if (
+                    self.eval_explore_step_num is None
+                    and self.explore_step_num % self.config.explorer.eval_interval == 0
+                ):
+                    self.eval_explore_step_num = self.explore_step_num
                 explore_contionue = self.explore_step()
                 if not explore_contionue:
                     break
                 if self.need_sync():
                     self.wait_for_workflow_done()
                     await self.sync_weight()
-                if self.explore_step_num % self.config.explorer.eval_interval == 0:
-                    self.wait_for_workflow_done()
-                    self.eval()
             except Exception as e:
                 self.logger.error(f"Error in Explorer: {e}")
                 break
@@ -216,16 +226,18 @@ def need_sync(self) -> bool:
             self.explore_step_num - self.config.synchronizer.sync_offset
         ) % self.config.synchronizer.sync_interval == 0

-    def eval(self) -> Tuple[bool, int]:
+    def eval(self, eval_explore_step_num: int):
         """Evaluation on all evaluation data samples."""
         if len(self.config.buffer.explorer_input.eval_tasksets) == 0:
             self.logger.warning("No evaluation data samples. Skip evaluation.")
-            return True, self.explore_step_num
-        self.logger.info("Evaluation started.")
+            return
+        self.logger.info(f"Evaluation at step {eval_explore_step_num} started.")
         all_st = time.time()
         log_metrics = {}
         for eval_taskset_config in self.config.buffer.explorer_input.eval_tasksets:
-            self.logger.info(f"Evaluation on {eval_taskset_config.name} started.")
+            self.logger.info(
+                f"Evaluation on {eval_taskset_config.name} at step {eval_explore_step_num} started."
+            )
             eval_taskset = get_buffer_reader(eval_taskset_config, self.config.buffer)
             st = time.time()
             all_metrics = defaultdict(list)
@@ -254,18 +266,19 @@ def wait():
             log_metrics.update(metrics)
             log_metrics[f"eval/{eval_taskset.name}/time"] = time.time() - st
         log_metrics["eval/total_time"] = time.time() - all_st
-        self.monitor.log(log_metrics, step=self.explore_step_num)  # type: ignore
-        self.logger.info("Evaluation finished.")
-        return True, self.explore_step_num
+        self.monitor.log(log_metrics, step=eval_explore_step_num)  # type: ignore
+        self.logger.info(f"Evaluation at step {eval_explore_step_num} finished.")

     async def benchmark(self) -> bool:
         """Benchmark the model checkpoints."""
         # benchmark on the latest checkpoint
         if self.config.explorer.eval_on_latest_checkpoint:
             await self._checkpoint_weights_update()
-            self.eval()
+            self.eval(self.explore_step_num)
             return True

+        # benchmark on base model
+        self.eval(0)
         # benchmark on all checkoints
         all_ckp_steps = sorted(
             [
@@ -276,9 +289,8 @@ async def benchmark(self) -> bool:
             ]
         )
         for step_num in all_ckp_steps:
-            self.explore_step_num = step_num
             await self._checkpoint_weights_update(step_num=step_num)
-            self.eval()
+            self.eval(step_num)
         return True

     def wait_for_workflow_done(self) -> None:
@@ -302,6 +314,10 @@ def wait_for_workflow_done(self) -> None:
             else:
                 for metric_name, metric_value in status.metric.items():
                     all_metrics[metric_name].append(metric_value)
+        # eval
+        if self.eval_explore_step_num is not None:
+            self.eval(self.eval_explore_step_num)
+            self.eval_explore_step_num = None
         # calculate metrics
         log_metrics = self.monitor.calculate_metrics(all_metrics, prefix="rollout")  # type: ignore
         self.monitor.log(log_metrics, step=self.explore_step_num)
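The patch defers evaluation: instead of calling `eval()` inline whenever the step count hits the eval interval, the explorer only records the step at which an eval is due (`eval_explore_step_num`) and runs it inside `wait_for_workflow_done()`, after the in-flight rollout workflows have been drained, logging metrics against the recorded step. The sketch below illustrates that deferral pattern in isolation; `DeferredEvalLoop` and its methods are hypothetical stand-ins, not the project's actual `Explorer` API.

```python
# Minimal, self-contained sketch (assumed names, not the real Explorer class)
# of the deferred-evaluation pattern introduced by the patch.
from typing import Optional


class DeferredEvalLoop:
    def __init__(self, eval_interval: int, sync_interval: int, total_steps: int):
        self.eval_interval = eval_interval
        self.sync_interval = sync_interval
        self.total_steps = total_steps
        self.explore_step_num = 0
        self.eval_explore_step_num: Optional[int] = None  # step of the pending eval

    def explore_step(self) -> bool:
        if self.explore_step_num >= self.total_steps:
            return False  # stand-in for an exhausted taskset
        self.explore_step_num += 1
        print(f"explore step {self.explore_step_num}")
        return True

    def eval(self, step: int) -> None:
        # stand-in for running the eval tasksets and logging their metrics
        print(f"eval metrics logged at step {step}")

    def wait_for_workflow_done(self) -> None:
        # drain in-flight rollouts, then run the pending eval (if any)
        print("rollout workflows drained")
        if self.eval_explore_step_num is not None:
            self.eval(self.eval_explore_step_num)
            self.eval_explore_step_num = None

    def run(self) -> None:
        while True:
            # Mark an eval as pending *before* the next step, so step 0
            # (the base model) is evaluated as well.
            if (
                self.eval_explore_step_num is None
                and self.explore_step_num % self.eval_interval == 0
            ):
                self.eval_explore_step_num = self.explore_step_num
            if not self.explore_step():
                break
            if self.explore_step_num % self.sync_interval == 0:  # need_sync()
                self.wait_for_workflow_done()
        # final drain so an eval pending at the last step is not lost
        self.wait_for_workflow_done()


if __name__ == "__main__":
    DeferredEvalLoop(eval_interval=4, sync_interval=4, total_steps=8).run()
```

Running the sketch with an eval interval of 4 over 8 explore steps prints evaluation logs at steps 0, 4, and 8, which is consistent with the `[0, 4, 8]` metric steps the updated test now asserts.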