diff --git a/trinity/cli/launcher.py b/trinity/cli/launcher.py
index 0dd846d8ec..21780e0e03 100644
--- a/trinity/cli/launcher.py
+++ b/trinity/cli/launcher.py
@@ -13,6 +13,20 @@
 logger = get_logger(__name__)
 
 
+def bench(config: Config) -> None:
+    """Evaluate model."""
+    explorer = Explorer.remote(config)
+    try:
+        ray.get(explorer.prepare.remote())
+        ray.get(explorer.sync_weight.remote())
+        _, step = ray.get(explorer.eval.remote())
+        logger.info("Evaluation finished.")
+        ray.get(explorer.flush_log.remote(step=step))
+    except Exception as e:
+        logger.error(f"Evaluation failed: {e}")
+        raise e
+
+
 def explore(config: Config) -> None:
     """Run explorer."""
     explorer = Explorer.remote(config)
@@ -151,6 +165,8 @@ def run(config_path: str):
         train(config)
     elif config.mode == "both":
         both(config)
+    elif config.mode == "bench":
+        bench(config)
 
 
 def studio(port: int = 8501):
diff --git a/trinity/common/config.py b/trinity/common/config.py
index 0970337071..a7a7b67d78 100644
--- a/trinity/common/config.py
+++ b/trinity/common/config.py
@@ -238,7 +238,7 @@ class SynchronizerConfig:
 class Config:
     """Global Configuration"""
 
-    mode: str = "both"  # `explore`, `train` or `both`
+    mode: str = "both"  # `explore`, `train`, `both` or `bench`
     data: DataConfig = field(default_factory=DataConfig)
     model: ModelConfig = field(default_factory=ModelConfig)
     cluster: ClusterConfig = field(default_factory=ClusterConfig)
@@ -302,7 +302,7 @@ def _check_buffer(self) -> None:
     def check_and_update(self) -> None:  # noqa: C901
         """Check and update the config."""
         # check mode
-        if self.mode not in ["explore", "train", "both"]:
+        if self.mode not in ["explore", "train", "both", "bench"]:
             raise ValueError(f"Invalid mode: {self.mode}")
         if self.trainer.algorithm_type == AlgorithmType.DPO and self.mode == "both":
             raise ValueError("DPO does not support `both` mode")
@@ -325,6 +325,11 @@ def check_and_update(self) -> None:  # noqa: C901
                 self.explorer.engine_num * self.explorer.tensor_parallel_size
             )
             self.synchronizer.backend = self.explorer.backend
+        if self.mode == "bench" and self.synchronizer.sync_method != SyncMethod.CHECKPOINT:
+            self.synchronizer.sync_method = SyncMethod.CHECKPOINT
+            logger.warning(
+                "Bench mode only supports checkpoint synchronization; `synchronizer.sync_method` has been set to `checkpoint`."
+            )
         if (
             self.trainer.algorithm_type == AlgorithmType.DPO
             and self.synchronizer.sync_method != SyncMethod.CHECKPOINT
diff --git a/trinity/common/verl_config.py b/trinity/common/verl_config.py
index ac8bf6ce21..93a3c82d9e 100644
--- a/trinity/common/verl_config.py
+++ b/trinity/common/verl_config.py
@@ -318,9 +318,10 @@ def synchronize_config(self, config: Config) -> None:
             self.actor_rollout_ref.actor.use_kl_loss = True
             logger.warning("DPO must use KL loss.")
             logger.warning("DPO micro batch size is doubled for computing loss.")
-            self.actor_rollout_ref.actor.ppo_mini_batch_size *= 2
             self.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu *= 2  # type: ignore
             self.actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu *= 2  # type: ignore
+            if self.actor_rollout_ref.rollout.n != 2:
+                self.actor_rollout_ref.rollout.n = 2
 
         # TODO: check other fields
         self.enable_preview = config.trainer.enable_preview
diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py
index 49a0b60e6f..784d7b526c 100644
--- a/trinity/explorer/explorer.py
+++ b/trinity/explorer/explorer.py
@@ -228,11 +228,11 @@ def explore_one_period(self) -> Tuple[bool, int]:
         self.logger.info(f"Explore step {self.step_num} finished.")
         return True, self.step_num
 
-    def eval(self) -> bool:
+    def eval(self) -> Tuple[bool, int]:
         """Evaluation on all evaluation data samples."""
         if self.eval_taskset is None:
             self.logger.warning("No evaluation data samples. Skip evaluation.")
-            return True
+            return True, self.step_num
         self.logger.info("Evaluation started.")
         st = time.time()
         all_metrics = defaultdict(list)
@@ -255,7 +255,7 @@ def eval(self) -> bool:
         log_metrics = self.monitor.calculate_metrics(all_metrics, prefix="eval")  # type: ignore
         log_metrics["eval/total_time"] = time.time() - st
         self.monitor.log(log_metrics, step=self.step_num)  # type: ignore
-        return True
+        return True, self.step_num
 
     def sync_weight(self) -> None:
         """Synchronize model weights."""
diff --git a/trinity/trainer/trainer.py b/trinity/trainer/trainer.py
index 1ea242ea10..607b01dc7d 100644
--- a/trinity/trainer/trainer.py
+++ b/trinity/trainer/trainer.py
@@ -73,8 +73,20 @@ def train_step(self, algo_type: AlgorithmType = AlgorithmType.PPO) -> Tuple[bool
             bool: Whether to continue training.
         """
         self.engine.set_mode(algo_type)
+        if algo_type.is_rft() and self.config.trainer.get_exp_strategy:
+            strategy = ReadStrategy(self.config.trainer.get_exp_strategy)
+        else:
+            strategy = None
+        try:
+            if algo_type.is_sft():
+                exps = self.sft_warmup_buffer.read()
+            else:
+                exps = self.train_buffer.read(strategy=strategy)
+        except StopIteration:
+            self.logger.warning("No more data to train. Stop training.")
+            return False, 0  # TODO: get the actual step number
+
         if algo_type.is_sft():
-            exps = self.sft_warmup_buffer.read()
            return self.engine.train_sft_step(
                Experiences.gather_experiences(
                    exps,
@@ -82,15 +94,6 @@ def train_step(self, algo_type: AlgorithmType = AlgorithmType.PPO) -> Tuple[bool
                 )
             )
         elif algo_type.is_rft():
-            if self.config.trainer.get_exp_strategy:
-                strategy = ReadStrategy(self.config.trainer.get_exp_strategy)
-            else:
-                strategy = None
-            try:
-                exps = self.train_buffer.read(strategy=strategy)
-            except StopIteration:
-                self.logger.warning("No more data to train. Stop training.")
-                return False, 0  # TODO: get the actual step number
             return self.engine.train_rft_step(
                 Experiences.gather_experiences(
                     exps,
@@ -98,7 +101,6 @@ def train_step(self, algo_type: AlgorithmType = AlgorithmType.PPO) -> Tuple[bool
                 )
             )
         elif algo_type.is_dpo():
-            exps = self.train_buffer.read()
             return self.engine.train_dpo_step(
                 Experiences.gather_dpo_experiences(
                     exps,