diff --git a/trinity/trainer/trainer.py b/trinity/trainer/trainer.py
index 1378449cf2..19aac1b571 100644
--- a/trinity/trainer/trainer.py
+++ b/trinity/trainer/trainer.py
@@ -10,6 +10,7 @@
 import ray
 
+from trinity.algorithm import SAMPLE_STRATEGY
 from trinity.common.config import Config
 from trinity.common.constants import RunningStatus, SyncMethod
 from trinity.utils.log import get_logger
 
@@ -23,6 +24,11 @@ def __init__(self, config: Config) -> None:
         self.logger = get_logger(__name__)
         self.engine = get_trainer_wrapper(config)
         self.explorer_ref = None
+        self.sample_strategy = SAMPLE_STRATEGY.get(config.algorithm.sample_strategy)(
+            buffer_config=config.buffer,
+            trainer_type=config.trainer.trainer_type,
+            **config.algorithm.sample_strategy_args,
+        )
 
     def prepare(self) -> None:
         """Prepare the trainer."""
@@ -32,7 +38,30 @@ def train(self) -> str:
         """Train the model."""
         while True:
             try:
-                train_continue = self.train_step()
+                # sample experiences for train step
+                try:
+                    batch, sample_metrics, exp_samples = self.sample_strategy.sample(
+                        self.engine.global_steps + 1,
+                    )
+                    successful_sampling = True
+                except StopIteration:
+                    print("No more data to train. Stop training.")
+                    if (
+                        self.engine.config.trainer.save_freq == 0
+                        or self.engine.global_steps % self.engine.config.trainer.save_freq != 0
+                    ):  # TODO: double-check this if-condition
+                        self.engine.logger.info(f"Saving at step {self.engine.global_steps}.")
+                        self.engine.save_checkpoint()
+                        self.engine.logger.info(f"Saved at step {self.engine.global_steps}.")
+                    successful_sampling = False
+                # TODO: get rid of self.engine.global_steps/config/logger?
+
+                # run train step
+                if successful_sampling:
+                    train_continue = self.engine.train_step(batch, sample_metrics, exp_samples)
+                else:
+                    train_continue = False
+
                 if not train_continue:
                     break
                 if self.need_sync():
@@ -43,14 +72,6 @@ def train(self) -> str:
         self.logger.info("--------------------\n> Trainer finished.\n--------------------")
         return self.config.trainer.name
 
-    def train_step(self) -> bool:
-        """Train one step.
-
-        Returns:
-            bool: Whether to continue training.
-        """
-        return self.engine.train_step()
-
     def need_sync(self) -> bool:
         """Whether to sync the model weight."""
         return self.engine.train_step_num % self.config.synchronizer.sync_interval == 0
@@ -95,7 +116,7 @@ def train_step_num(self) -> int:
         """Get the current training step number."""
 
     @abstractmethod
-    def train_step(self) -> bool:
+    def train_step(self, batch, sample_metrics, exp_samples) -> bool:
         """Training."""
 
     @abstractmethod
diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py
index 2a8b1d0135..9d48421b43 100644
--- a/trinity/trainer/verl_trainer.py
+++ b/trinity/trainer/verl_trainer.py
@@ -31,7 +31,7 @@
 from verl.utils import hf_tokenizer
 from verl.utils.fs import copy_local_path_from_hdfs
 
-from trinity.algorithm import ADVANTAGE_FN, KL_FN, SAMPLE_STRATEGY
+from trinity.algorithm import ADVANTAGE_FN, KL_FN
 from trinity.algorithm.algorithm import ALGORITHM_TYPE, SFTAlgorithm
 from trinity.algorithm.algorithm_manager import AlgorithmManager
 from trinity.algorithm.utils import prefix_metrics
@@ -134,11 +134,11 @@ def __init__(
         self.kl_fn = KL_FN.get(self.algorithm_config.kl_penalty_fn)(
             **self.algorithm_config.kl_penalty_fn_args
         )
-        self.sample_strategy = SAMPLE_STRATEGY.get(global_config.algorithm.sample_strategy)(
-            buffer_config=global_config.buffer,
-            trainer_type=global_config.trainer.trainer_type,
-            **global_config.algorithm.sample_strategy_args,
-        )
+        # self.sample_strategy = SAMPLE_STRATEGY.get(global_config.algorithm.sample_strategy)(
+        #     buffer_config=global_config.buffer,
+        #     trainer_type=global_config.trainer.trainer_type,
+        #     **global_config.algorithm.sample_strategy_args,
+        # )
         super().__init__(
             config,
             tokenizer,
@@ -287,22 +287,25 @@ def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampl
         # TODO: compute total training steps
         self.total_training_steps = self.config.trainer.total_training_steps or sys.maxsize
 
-    def train_step(self) -> bool:  # noqa C901
+    def train_step(self, batch, sample_metrics, exp_samples) -> bool:  # noqa C901
         self.logger.info(f"Training at step {self.global_steps + 1} started.")
         metrics = {}
-        try:
-            batch, sample_metrics, exp_samples = self.sample_strategy.sample(self.global_steps + 1)
-            prefix_metrics(sample_metrics, "sample", metrics)
-        except StopIteration:
-            print("No more data to train. Stop training.")
-            if (
-                self.config.trainer.save_freq == 0
-                or self.global_steps % self.config.trainer.save_freq != 0
-            ):
-                self.logger.info(f"Saving at step {self.global_steps}.")
-                self._save_checkpoint()
-                self.logger.info(f"Saved at step {self.global_steps}.")
-            return False
+        prefix_metrics(sample_metrics, "sample", metrics)
+
+        # try:
+        #     batch, sample_metrics, exp_samples = self.sample_strategy.sample(self.global_steps + 1)
+        #     prefix_metrics(sample_metrics, "sample", metrics)
+        # except StopIteration:
+        #     print("No more data to train. Stop training.")
+        #     if (
+        #         self.config.trainer.save_freq == 0
+        #         or self.global_steps % self.config.trainer.save_freq != 0
+        #     ):
+        #         self.logger.info(f"Saving at step {self.global_steps}.")
+        #         self._save_checkpoint()
+        #         self.logger.info(f"Saved at step {self.global_steps}.")
+        #     return False
+
         self.global_steps += 1
         self.logger.info(f"Sampling at step {self.global_steps} done.")
         timing_raw = {}