diff --git a/docs/sphinx_doc/source/tutorial/example_dpo.md b/docs/sphinx_doc/source/tutorial/example_dpo.md index ab28f9802a..0e01f879d3 100644 --- a/docs/sphinx_doc/source/tutorial/example_dpo.md +++ b/docs/sphinx_doc/source/tutorial/example_dpo.md @@ -40,7 +40,7 @@ Note that the dataset has the keys `prompt`, `chosen` and `rejected`. If not, pa We use the configurations in [`dpo.yaml`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/dpo_humanlike/dpo.yaml) and [`train_dpo.yaml`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/dpo_humanlike/train_dpo.yaml) for this experiment. Some important setups are listed in the following: -We run the experiment in a train mode, as there is no Explorer. To enable this mode, we config `mode` to `train` and set `sync_method` to `checkpoint`. The value of `sync_iteration_interval` can be set as same of the value of `save_interval`. +We run the experiment in a train mode, as there is no Explorer. To enable this mode, we config `mode` to `train` and set `sync_method` to `checkpoint`. ```yaml # In dpo.yaml @@ -50,7 +50,6 @@ synchronizer: buffer: train_dataset: storage_type: file - algorithm_type: dpo path: <$DATASET_PATH/human_like_dpo_dataset> kwargs: prompt_type: # messages/plaintext diff --git a/docs/sphinx_doc/source/tutorial/example_reasoning_advanced.md b/docs/sphinx_doc/source/tutorial/example_reasoning_advanced.md index 9307ea9e9f..68ce684b81 100644 --- a/docs/sphinx_doc/source/tutorial/example_reasoning_advanced.md +++ b/docs/sphinx_doc/source/tutorial/example_reasoning_advanced.md @@ -20,7 +20,7 @@ To try out the OPMD algorithm: trinity run --config examples/opmd_gsm8k/opmd_gsm8k.yaml ``` -Note that in this config file, `sync_iteration_interval` is set to 10, i.e., the model weights of explorer and trainer are synchronized only once every 10 training steps, which leads to a challenging off-policy scenario (potentially with abrupt distribution shift during the RFT process). +Note that in this config file, `sync_interval` is set to 10, i.e., the model weights of explorer and trainer are synchronized only once every 10 training steps, which leads to a challenging off-policy scenario (potentially with abrupt distribution shift during the RFT process). Other configurations of particular interest are explained at the beginning of [`train_opmd_gsm8k.yaml`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/opmd_gsm8k/train_opmd_gsm8k.yaml). @@ -30,7 +30,7 @@ Other configurations of particular interest are explained at the beginning of [` The red curve below shows an example of OPMD's learning curves. Since the explorer's model weights remain unchanged for the first 10 steps, its score remains flat. Then, after the model weights of explorer and trainer are synchronized at the end of step 10, we see an abrupt increase in score at step 11, which indicates effective off-policy learning in the first 10 steps. -A similar performance boost is shown at step 21, which leads to a converged score matching what is achieved by GRPO in a mostly on-policy case (with `sync_iteration_interval=2`). +A similar performance boost is shown at step 21, which leads to a converged score matching what is achieved by GRPO in a mostly on-policy case (with `sync_interval=2`). 
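The `sync_interval` setting discussed in the documentation hunks above (renamed from `sync_iteration_interval`) is the number of trainer steps between weight synchronizations of explorer and trainer. The sketch below uses hypothetical `ToyExplorer`/`ToyTrainer` classes rather than anything from Trinity-RFT; it only illustrates why a large interval such as 10 yields off-policy rollouts while a small one such as 2 stays close to on-policy.

```python
# A minimal sketch (toy classes, not the Trinity-RFT API) of what `sync_interval`
# controls: how many trainer steps elapse before the explorer's weights are
# refreshed. A large interval (e.g. 10, as in the OPMD example) means rollouts
# come from weights that are up to 9 updates stale -- an off-policy regime.

class ToyExplorer:
    def __init__(self) -> None:
        self.weights_version = 0  # version of the policy used for rollouts

    def rollout(self, step: int) -> dict:
        # Each experience records which policy version generated it.
        return {"step": step, "behavior_version": self.weights_version}


class ToyTrainer:
    def __init__(self) -> None:
        self.weights_version = 0  # advances by one per optimizer update

    def train_step(self, batch: dict) -> None:
        self.weights_version += 1


def run(total_steps: int, sync_interval: int) -> None:
    explorer, trainer = ToyExplorer(), ToyTrainer()
    for step in range(1, total_steps + 1):
        batch = explorer.rollout(step)
        trainer.train_step(batch)
        staleness = trainer.weights_version - batch["behavior_version"]
        print(f"step {step}: rollout was {staleness - 1} updates behind the trainer")
        if step % sync_interval == 0:
            explorer.weights_version = trainer.weights_version  # weight sync


run(total_steps=20, sync_interval=10)  # off-policy, as in the OPMD example
run(total_steps=20, sync_interval=2)   # close to on-policy, as in the GRPO example
```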
diff --git a/docs/sphinx_doc/source/tutorial/example_reasoning_basic.md b/docs/sphinx_doc/source/tutorial/example_reasoning_basic.md index ba0b013cdd..e9a6d9b594 100644 --- a/docs/sphinx_doc/source/tutorial/example_reasoning_basic.md +++ b/docs/sphinx_doc/source/tutorial/example_reasoning_basic.md @@ -37,13 +37,13 @@ More details on dataset downloading are referred to [ModelScope](https://modelsc ### Synchronous Mode of Trinity-RFT -We run the experiment in a synchronous mode where the Explorer and Trainer operate in turn. To enable this mode, we config `mode` to `both` (default) and set `sync_iteration_interval` properly. A smaller value of `sync_iteration_interval` makes the training closer to an on-policy setup. +We run the experiment in a synchronous mode where the Explorer and Trainer operate in turn. To enable this mode, we config `mode` to `both` (default) and set `sync_interval` properly. A smaller value of `sync_interval` makes the training closer to an on-policy setup. ```yaml mode: both synchronizer: sync_method: 'nccl' - sync_iteration_interval: 2 + sync_interval: 2 ``` ### Use GRPO or PPO Algorithm @@ -76,21 +76,20 @@ trinity run --config examples/grpo_gsm8k/gsm8k.yaml ## Optional: RFT with SFT Warmup -Before RFT, we may use SFT as a warmup step. We need to set `trainer.sft_warmup_iteration > 0` and prepare the SFT data to `buffer.train_dataset.path=$DATASET_PATH/{sft_data}`. +Before RFT, we may use SFT as a warmup step. We need to set `trainer.sft_warmup_steps > 0` and prepare the SFT data to `buffer.train_dataset.path=$DATASET_PATH/{sft_data}`. ```yaml # Properly set the following configs in gsm8k.yaml buffer: sft_warmup_dataset: storage_type: file - algorithm_type: sft path: <$DATASET_PATH/{sft_data}> kwargs: prompt_type: # messages/plaintext/chatpair prompt_key: response_key: trainer: - sft_warmup_iteration: 10 + sft_warmup_steps: 10 ``` The following command runs SFT and RFT in sequence: diff --git a/docs/sphinx_doc/source/tutorial/trinity_configs.md b/docs/sphinx_doc/source/tutorial/trinity_configs.md index de6f26252d..6983162cc7 100644 --- a/docs/sphinx_doc/source/tutorial/trinity_configs.md +++ b/docs/sphinx_doc/source/tutorial/trinity_configs.md @@ -49,7 +49,7 @@ data: - `data.max_retry_times`: The maximum number of retries when loading the dataset from database. - `data.max_retry_interval`: The maximum interval between retries when loading the dataset from database. - `data.total_epochs`: The total number of epochs to explore the dataset. Default is `1`. It should be set manually. -- `data.batch_size`: The number of `Task` in one training batch. The real batch size used in training is `data.batch_size` * `actor_rollout_ref.rollout.n` Default is `1`. It should be set manually. +- `data.batch_size`: The number of `Task` in one training batch. The real batch size used in training is `data.batch_size` * `explorer.repeat_times`. It should be set manually. - `data.default_workflow_type`: The default workflow type used for training. - `data.default_reward_fn_type`: The default reward function type used for training. @@ -150,14 +150,14 @@ explorer: ```yaml synchronizer: sync_method: 'nccl' - sync_iteration_interval: 10 + sync_interval: 10 sync_timeout: 1200 ``` - `synchronizer.sync_method`: The synchronization method between `trainer` and `explorer`. 
Support `nccl` and `checkpoint`, `nccl` represents that model weights in `explorer` will be synchronized from `trainer` through `nccl`, `checkpoint` represents that `explorer` will load the newest checkpoints saved by `trainer` then update its model weights. Default is `nccl`. -- `synchronizer.sync_iteration_interval`: The interval between two synchronizations. Default is `10`. It should be set manually. +- `synchronizer.sync_interval`: The interval between two synchronizations. Default is `10`. It should be set manually. - `synchronizer.sync_timeout`: The timeout of the synchronization. Default is `1200`. ## Trainer @@ -167,7 +167,7 @@ trainer: trainer_type: 'verl' algorithm_type: ppo trainer_config_path: 'examples/ppo_countdown/train_countdown.yaml' - sft_warmup_iteration: 0 + sft_warmup_steps: 0 eval_interval: 1000 save_interval: 100 ``` @@ -175,7 +175,7 @@ trainer: - `trainer.trainer_type`: The backend of the trainer, Only `verl` is supported. - `trainer.algorithm_type`: The type of the algorithm, Support `ppo`, `grpo`, `opmd` and `dpo`. - `trainer.trainer_config_path`: The path to the trainer configuration file. It must be set manually. -- `trainer.sft_warmup_iteration`: The number of iterations to warm up the model. Default is `0`. +- `trainer.sft_warmup_steps`: The number of steps to warm up the model. Default is `0`. - `trainer.eval_interval`: The interval between two evaluations. Default is `1000`. - `trainer.save_interval`: The interval between two checkpoints. Default is `100`. @@ -418,7 +418,7 @@ trainer: - `trainer.balance_batch`: Whether to balance batch size between GPUs during training. - `trainer.resume_mode`: Resume mode for training. Support `disable`, `auto` and `resume_path`. - `trainer.resume_from_path`: Path to resume from. -- `trainer.critic_warmup`: The number of iteration to train the critic model before actual policy learning. +- `trainer.critic_warmup`: The number of steps to train the critic model before actual policy learning. - `trainer.default_hdfs_dir`: Default HDFS directory for saving checkpoints. - `trainer.remove_previous_ckpt_in_save`: Whether to remove previous checkpoints in save. - `trainer.del_local_ckpt_after_load`: Whether to delete local checkpoints after loading. 
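Two of the settings documented above combine or constrain each other: the real per-step training batch is `data.batch_size * explorer.repeat_times`, and `eval_interval` is expected to be a multiple of `sync_interval` (the config check shown later in this patch rounds it down otherwise). The sketch below uses a hypothetical helper and made-up numbers purely to illustrate that arithmetic; it is not part of the Trinity-RFT codebase.

```python
# A hedged sketch of the arithmetic described in the configuration guide above.
# `effective_settings` is a hypothetical helper mirroring two documented rules:
# the real training batch size is data.batch_size * explorer.repeat_times, and
# eval_interval is adjusted to a multiple of sync_interval.

def effective_settings(batch_size: int, repeat_times: int,
                       eval_interval: int, sync_interval: int) -> dict:
    train_batch_size = batch_size * repeat_times
    if eval_interval % sync_interval != 0:
        # Round down to the nearest positive multiple of sync_interval.
        eval_interval = max(eval_interval // sync_interval, 1) * sync_interval
    return {"train_batch_size": train_batch_size, "eval_interval": eval_interval}


print(effective_settings(batch_size=96, repeat_times=8,
                         eval_interval=50, sync_interval=2))
# {'train_batch_size': 768, 'eval_interval': 50}

print(effective_settings(batch_size=96, repeat_times=8,
                         eval_interval=50, sync_interval=8))
# {'train_batch_size': 768, 'eval_interval': 48}  (rounded to a multiple of 8)
```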
diff --git a/examples/async_gsm8k/explorer.yaml b/examples/async_gsm8k/explorer.yaml index 73f8669b10..82ad223baa 100644 --- a/examples/async_gsm8k/explorer.yaml +++ b/examples/async_gsm8k/explorer.yaml @@ -48,9 +48,9 @@ synchronizer: sync_iteration_interval: 10 trainer: trainer_type: 'verl' - algorithm_type: ppo + algorithm_type: grpo trainer_config_path: examples/async_gsm8k/verl_config.yaml - sft_warmup_iteration: 0 # Set to integer to enable sft warmup + sft_warmup_steps: 0 # Set to integer to enable sft warmup eval_interval: 10 monitor: cache_root_dir: "" diff --git a/examples/async_gsm8k/trainer.yaml b/examples/async_gsm8k/trainer.yaml index d2dc92503a..e67d325ca2 100644 --- a/examples/async_gsm8k/trainer.yaml +++ b/examples/async_gsm8k/trainer.yaml @@ -48,9 +48,9 @@ synchronizer: sync_iteration_interval: 10 trainer: trainer_type: 'verl' - algorithm_type: ppo + algorithm_type: grpo trainer_config_path: examples/async_gsm8k/verl_config.yaml - sft_warmup_iteration: 0 # Set to integer to enable sft warmup + sft_warmup_steps: 0 # Set to integer to enable sft warmup eval_interval: 10 monitor: cache_root_dir: "" diff --git a/examples/dpo_humanlike/dpo.yaml b/examples/dpo_humanlike/dpo.yaml index 5d03d7130c..6763254dd9 100644 --- a/examples/dpo_humanlike/dpo.yaml +++ b/examples/dpo_humanlike/dpo.yaml @@ -46,7 +46,7 @@ explorer: max_waiting_steps: 4 synchronizer: sync_method: 'checkpoint' - sync_iteration_interval: 30 + sync_interval: 30 sync_timeout: 1200 trainer: trainer_type: 'verl' diff --git a/examples/grpo_alfworld/alfworld.yaml b/examples/grpo_alfworld/alfworld.yaml index 1881f78d36..6c70a12d75 100644 --- a/examples/grpo_alfworld/alfworld.yaml +++ b/examples/grpo_alfworld/alfworld.yaml @@ -39,14 +39,14 @@ explorer: max_pending_requests: 32 max_waiting_steps: 4 gpu_memory_utilization: 0.7 - enable_chunked_prefil: true + enable_chunked_prefill: true synchronizer: sync_method: 'nccl' - sync_iteration_interval: 8 + sync_interval: 8 sync_timeout: 1200 trainer: trainer_type: 'verl' - algorithm_type: ppo + algorithm_type: grpo trainer_config_path: 'examples/grpo_alfworld/train_alfworld.yaml' save_interval: 10 monitor: diff --git a/examples/grpo_gsm8k/gsm8k.yaml b/examples/grpo_gsm8k/gsm8k.yaml index 9dd620c0d7..a5ea536bff 100644 --- a/examples/grpo_gsm8k/gsm8k.yaml +++ b/examples/grpo_gsm8k/gsm8k.yaml @@ -60,13 +60,13 @@ explorer: max_waiting_steps: 4 synchronizer: sync_method: 'nccl' - sync_iteration_interval: 2 + sync_interval: 2 sync_timeout: 1200 trainer: trainer_type: 'verl' - algorithm_type: ppo + algorithm_type: grpo trainer_config_path: 'examples/grpo_gsm8k/train_gsm8k.yaml' - sft_warmup_iteration: 0 # Set to integer to enable sft warmup + sft_warmup_steps: 0 # Set to integer to enable sft warmup eval_interval: 50 save_interval: 100 # get_exp_strategy: 'LFU' diff --git a/examples/grpo_math/math.yaml b/examples/grpo_math/math.yaml index db6a347bc9..c1d1bc0f15 100644 --- a/examples/grpo_math/math.yaml +++ b/examples/grpo_math/math.yaml @@ -46,13 +46,13 @@ explorer: max_waiting_steps: 4 synchronizer: sync_method: 'nccl' - sync_iteration_interval: 2 + sync_interval: 2 sync_timeout: 1200 trainer: trainer_type: 'verl' - algorithm_type: ppo + algorithm_type: grpo trainer_config_path: 'examples/grpo_math/train_math.yaml' - sft_warmup_iteration: 0 # Set to integer to enable sft warmup + sft_warmup_steps: 0 # Set to integer to enable sft warmup eval_interval: 10 save_interval: 100 monitor: diff --git a/examples/grpo_sciworld/sciworld.yaml b/examples/grpo_sciworld/sciworld.yaml index 
53dbdea801..1b85571e23 100644 --- a/examples/grpo_sciworld/sciworld.yaml +++ b/examples/grpo_sciworld/sciworld.yaml @@ -39,14 +39,14 @@ explorer: max_pending_requests: 32 max_waiting_steps: 4 gpu_memory_utilization: 0.7 - enable_chunked_prefil: true + enable_chunked_prefill: true synchronizer: sync_method: 'nccl' - sync_iteration_interval: 8 + sync_interval: 8 sync_timeout: 1200 trainer: trainer_type: 'verl' - algorithm_type: ppo + algorithm_type: grpo trainer_config_path: 'examples/grpo_sciworld/train_sciworld.yaml' save_interval: 10 monitor: diff --git a/examples/grpo_webshop/webshop.yaml b/examples/grpo_webshop/webshop.yaml index a301140c07..7bdebcf2fa 100644 --- a/examples/grpo_webshop/webshop.yaml +++ b/examples/grpo_webshop/webshop.yaml @@ -39,14 +39,14 @@ explorer: max_pending_requests: 32 max_waiting_steps: 4 gpu_memory_utilization: 0.7 - enable_chunked_prefil: true + enable_chunked_prefill: true synchronizer: sync_method: 'nccl' - sync_iteration_interval: 8 + sync_interval: 8 sync_timeout: 1200 trainer: trainer_type: 'verl' - algorithm_type: ppo + algorithm_type: grpo trainer_config_path: 'examples/grpo_webshop/train_webshop.yaml' save_interval: 10 monitor: diff --git a/examples/opmd_gsm8k/opmd_gsm8k.yaml b/examples/opmd_gsm8k/opmd_gsm8k.yaml index d5d60f7126..35a2cfe169 100644 --- a/examples/opmd_gsm8k/opmd_gsm8k.yaml +++ b/examples/opmd_gsm8k/opmd_gsm8k.yaml @@ -39,13 +39,13 @@ explorer: max_waiting_steps: 4 synchronizer: sync_method: 'nccl' - sync_iteration_interval: 10 + sync_interval: 10 sync_timeout: 1200 trainer: trainer_type: 'verl' algorithm_type: opmd trainer_config_path: 'examples/opmd_gsm8k/train_opmd_gsm8k.yaml' - sft_warmup_iteration: 0 + sft_warmup_steps: 0 save_interval: 100 monitor: cache_root_dir: "" diff --git a/examples/opmd_gsm8k/train_opmd_gsm8k.yaml b/examples/opmd_gsm8k/train_opmd_gsm8k.yaml index 033405f5c8..88f92fb461 100644 --- a/examples/opmd_gsm8k/train_opmd_gsm8k.yaml +++ b/examples/opmd_gsm8k/train_opmd_gsm8k.yaml @@ -68,8 +68,8 @@ actor_rollout_ref: # min_lr_ratio: null # only useful for warmup with cosine warmup_style: constant # select from constant/cosine total_training_steps: -1 # must be override by program - beta1: 0.0 # set to smaller value for scenarios with abrupt distribution shift (e.g., large sync_iteration_interval) - beta2: 0.95 # set to smaller value for scenarios with abrupt distribution shift (e.g., large sync_iteration_interval) + beta1: 0.0 # set to smaller value for scenarios with abrupt distribution shift (e.g., large sync_interval) + beta2: 0.95 # set to smaller value for scenarios with abrupt distribution shift (e.g., large sync_interval) fsdp_config: wrap_policy: # transformer_layer_cls_to_wrap: None diff --git a/examples/ppo_countdown/countdown.yaml b/examples/ppo_countdown/countdown.yaml index acc9c7950e..f1c1b4b31d 100644 --- a/examples/ppo_countdown/countdown.yaml +++ b/examples/ppo_countdown/countdown.yaml @@ -42,13 +42,13 @@ explorer: max_waiting_steps: 4 synchronizer: sync_method: 'nccl' - sync_iteration_interval: 10 + sync_interval: 10 sync_timeout: 1200 trainer: trainer_type: 'verl' algorithm_type: ppo trainer_config_path: 'examples/ppo_countdown/train_countdown.yaml' - sft_warmup_iteration: 0 + sft_warmup_steps: 0 eval_interval: 1000 save_interval: 100 monitor: diff --git a/tests/common/config_test.py b/tests/common/config_test.py index 67a372fb5b..001a592c42 100644 --- a/tests/common/config_test.py +++ b/tests/common/config_test.py @@ -19,7 +19,7 @@ def test_load_default_config(self): 
self.assertEqual(config.trainer.trainer_config.trainer.project_name, config.monitor.project) self.assertEqual( config.trainer.trainer_config.trainer.save_freq, - config.synchronizer.sync_iteration_interval, + config.synchronizer.sync_interval, ) def test_all_examples_are_valid(self): diff --git a/tests/data/core/dataset_test.py b/tests/data/core/dataset_test.py index ebb711b955..522abb13a2 100644 --- a/tests/data/core/dataset_test.py +++ b/tests/data/core/dataset_test.py @@ -82,12 +82,12 @@ def test_to_taskset(self): def test_to_taskset_with_global_settings(self): dataset = RftDataset(data_config=self.data_config, reward_schema="default") taskset = dataset.to_taskset( - reward_fn=AccuracyReward(), + reward_fn=AccuracyReward, workflow=SimpleWorkflow, ) self.assertIsInstance(taskset, TaskSet) self.assertEqual(taskset.workflow, SimpleWorkflow) - self.assertIsInstance(taskset.reward_fn, AccuracyReward) + self.assertEqual(taskset.reward_fn, AccuracyReward) def test_to_taskset_with_sample_level_settings(self): dataset = RftDataset( @@ -97,22 +97,22 @@ def test_to_taskset_with_sample_level_settings(self): self.assertIsInstance(taskset, TaskSet) for task in taskset.tasks: self.assertEqual(task.workflow, MathWorkflow) - self.assertIsInstance(task.reward_fn, AccuracyReward) + self.assertEqual(task.reward_fn, AccuracyReward) def test_to_taskset_with_both_settings(self): dataset = RftDataset( data_config=self.data_config_sample_level_setting, reward_schema="default" ) taskset = dataset.to_taskset( - reward_fn=AccuracyReward(), + reward_fn=AccuracyReward, workflow=SimpleWorkflow, ) self.assertIsInstance(taskset, TaskSet) for task in taskset.tasks: self.assertEqual(task.workflow, MathWorkflow) - self.assertIsInstance(task.reward_fn, AccuracyReward) + self.assertEqual(task.reward_fn, AccuracyReward) self.assertEqual(taskset.workflow, SimpleWorkflow) - self.assertIsInstance(taskset.reward_fn, AccuracyReward) + self.assertEqual(taskset.reward_fn, AccuracyReward) if __name__ == "__main__": diff --git a/tests/explorer/explorer_test.py b/tests/explorer/explorer_test.py index 7d1fbd4c88..99ec5a739c 100644 --- a/tests/explorer/explorer_test.py +++ b/tests/explorer/explorer_test.py @@ -24,7 +24,7 @@ def setUp(self): self.config.monitor.monitor_type = MonitorType.TENSORBOARD self.config.monitor.project = "Trinity-unittest" self.config.model.checkpoint_path = get_checkpoint_path() - self.config.synchronizer.sync_iteration_interval = 2 + self.config.synchronizer.sync_interval = 2 self.config.explorer.eval_interval = 4 self.config.trainer.eval_interval = 4 diff --git a/tests/template/config.yaml b/tests/template/config.yaml index 76cde18dc0..1f52409477 100644 --- a/tests/template/config.yaml +++ b/tests/template/config.yaml @@ -36,7 +36,7 @@ explorer: trainer: trainer_type: verl trainer_config_path: tests/template/verl_config.yaml - sft_warmup_iteration: 0 + sft_warmup_steps: 0 eval_interval: 1000 save_interval: 100 monitor: @@ -44,6 +44,6 @@ monitor: name: test synchronizer: sync_method: checkpoint - sync_iteration_interval: 10 + sync_interval: 10 sync_timeout: 1200 wait_for_checkpoint: false diff --git a/tests/tools.py b/tests/tools.py index 0aacf295cd..1bbcc767ef 100644 --- a/tests/tools.py +++ b/tests/tools.py @@ -33,7 +33,7 @@ def get_checkpoint_path() -> str: def get_unittest_dataset_config(dataset_name: str = "countdown") -> DataConfig: - """Countdown sample dataset for 8 iterations""" + """Countdown sample dataset for 8 steps""" if dataset_name == "countdown": return DataConfig( total_epochs=2, diff --git 
a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py index 4bb1f88685..9e68f71d70 100644 --- a/tests/trainer/trainer_test.py +++ b/tests/trainer/trainer_test.py @@ -29,7 +29,7 @@ def setUp(self): self.config.model.checkpoint_path = os.path.join( get_checkpoint_path(), f"train-{datetime.now().strftime('%Y%m%d%H%M%S')}" ) - self.config.synchronizer.sync_iteration_interval = 2 + self.config.synchronizer.sync_interval = 2 self.config.synchronizer.sync_method = SyncMethod.NCCL self.config.explorer.eval_interval = 4 self.config.trainer.eval_interval = 4 @@ -61,12 +61,12 @@ def test_trainer(self): self.assertTrue(len(response_metrics) > 0) self.assertEqual(parser.metric_max_step(response_metrics[0]), 8) # check checkpoint - from trinity.common.models.utils import get_checkpoint_dir_with_iteration + from trinity.common.models.utils import get_checkpoint_dir_with_step_num - checkpoint_dir = get_checkpoint_dir_with_iteration( + checkpoint_dir = get_checkpoint_dir_with_step_num( checkpoint_root_path=self.config.model.checkpoint_path, trainer_type=self.config.trainer.trainer_type, - iteration_num=None, + step_num=None, ) self.assertTrue(os.path.exists(checkpoint_dir)) self.assertTrue(checkpoint_dir.endswith("step_8")) diff --git a/trinity/cli/launcher.py b/trinity/cli/launcher.py index dda42ecaad..2cf706b661 100644 --- a/trinity/cli/launcher.py +++ b/trinity/cli/launcher.py @@ -32,10 +32,15 @@ def train(config: Config) -> None: trainer = Trainer.remote(config) ray.get(trainer.prepare.remote()) - if config.trainer.sft_warmup_iteration > 0: - for step in range(config.trainer.sft_warmup_iteration): - ray.get(trainer.train_step.remote(AlgorithmType.SFT)) - logger.info(f"SFT warmup step {step} finished.") + if config.trainer.sft_warmup_steps > 0: + while True: + train_continue, train_step_num = ray.get( + trainer.train_one_period.remote(AlgorithmType.SFT) + ) + logger.info(f"SFT warmup step {train_step_num} finished.") + if not train_continue: + logger.info("SFT warmup finished.") + break algo_type = config.trainer.algorithm_type try: @@ -49,7 +54,7 @@ def train(config: Config) -> None: def both(config: Config) -> None: """Setup both explorer and trainer. - For the explorer, a step contains `batch_size * sync_iteration_interval` number + For the explorer, a step contains `batch_size * sync_interval` number of rollout tasks. 
For the trainer, it has to consume all experiences generated by the explorer in @@ -69,19 +74,24 @@ def both(config: Config) -> None: # sync weight before training start ray.get([explorer.sync_weight.remote(), trainer.sync_weight.remote()]) - if config.trainer.sft_warmup_iteration > 0: - for step in range(config.trainer.sft_warmup_iteration): - ray.get(trainer.train_step.remote(AlgorithmType.SFT)) - logger.info(f"SFT warmup step {step} finished.") + if config.trainer.sft_warmup_steps > 0: + while True: + train_continue, train_step_num = ray.get( + trainer.train_one_period.remote(AlgorithmType.SFT) + ) + logger.info(f"SFT warmup step {train_step_num} finished.") + if not train_continue: + logger.info("SFT warmup finished.") + break ray.get([explorer.sync_weight.remote(), trainer.sync_weight.remote()]) algo_type = config.trainer.algorithm_type while True: try: - ref_explore = explorer.explore_step.remote() - ref_train = trainer.train_step.remote(algo_type) - explore_continue, explore_iter_num = ray.get(ref_explore) - train_continue, train_iter_num = ray.get(ref_train) + ref_explore = explorer.explore_one_period.remote() + ref_train = trainer.train_one_period.remote(algo_type) + explore_continue, explore_step_num = ray.get(ref_explore) + train_continue, train_step_num = ray.get(ref_train) if not explore_continue: # If explore finished, the trainer may not have enough experiences to continue, # which will cause the trainer be blocked. So we stop the training process @@ -98,7 +108,7 @@ def both(config: Config) -> None: logger.error(e) logger.error("Training stopped due to exception.") raise e - if train_iter_num % config.trainer.eval_interval == 0: + if train_step_num % config.trainer.eval_interval == 0: try: ray.get(explorer.eval.remote()) logger.info("Evaluation finished.") @@ -106,8 +116,8 @@ def both(config: Config) -> None: logger.error(e) logger.error("Evaluation failed.") raise e - ray.get(explorer.flush_log.remote(step=explore_iter_num)) - ray.get(trainer.flush_log.remote(step=train_iter_num)) + ray.get(explorer.flush_log.remote(step=explore_step_num)) + ray.get(trainer.flush_log.remote(step=train_step_num)) def activate_data_module(data_workflow_url: str, config_path: str): diff --git a/trinity/common/config.py b/trinity/common/config.py index d4bb08b27d..0970337071 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -172,7 +172,7 @@ class ExplorerConfig: backend: str = "nccl" use_ray: bool = False gpu_memory_utilization: float = 0.9 - enable_chunked_prefil: bool = False + enable_chunked_prefill: bool = False use_v1: bool = True bundle_indices: str = "" # DO NOT SET this field @@ -193,11 +193,12 @@ class TrainerConfig: trainer_config: Any = field(default_factory=dict) # train algorithm - algorithm_type: AlgorithmType = AlgorithmType.PPO + algorithm_type: AlgorithmType = AlgorithmType.PPO # automatically set get_exp_strategy: Optional[str] = None # warmup config - sft_warmup_iteration: int = 0 + sft_warmup_steps: int = 0 + sft_warmup_iteration: Optional[int] = None # deprecated @dataclass @@ -220,8 +221,10 @@ class SynchronizerConfig: # TODO: rename to "checkpoint", "nccl", "ipc" sync_method: SyncMethod = SyncMethod.NCCL - # sync weights every `sync_iteration_interval` iterations - sync_iteration_interval: int = 1 + # sync weights every `sync_interval` steps + sync_interval: int = 1 + # `sync_iteration_interval` is deprecated, use `sync_interval` instead + sync_iteration_interval: Optional[int] = None sync_timeout: int = 1200 # wait for the lastest checkpoint to be 
ready wait_for_checkpoint: bool = False @@ -251,9 +254,9 @@ def save(self, config_path: str) -> None: OmegaConf.save(self, f) def _check_buffer(self) -> None: - if self.trainer.sft_warmup_iteration > 0 and self.buffer.sft_warmup_dataset is None: + if self.trainer.sft_warmup_steps > 0 and self.buffer.sft_warmup_dataset is None: raise ValueError( - "buffer.sft_warmup_dataset is required when trainer.sft_warmup_iteration > 0" + "buffer.sft_warmup_dataset is required when trainer.sft_warmup_steps > 0" ) if self.buffer.db_url: raise ValueError( @@ -277,7 +280,6 @@ def _check_buffer(self) -> None: self.buffer.train_dataset = DatasetConfig( name="experience_buffer", storage_type=StorageType.QUEUE, - algorithm_type=self.trainer.algorithm_type, ) logger.info(f"Auto set `buffer.train_dataset` to {self.buffer.train_dataset}") else: # TODO: to be check @@ -286,12 +288,11 @@ def _check_buffer(self) -> None: self.buffer.train_dataset = DatasetConfig( name="dpo_train_dataset", storage_type=StorageType.FILE, - algorithm_type=self.trainer.algorithm_type, ) logger.info(f"Auto set `buffer.train_dataset` to {self.buffer.train_dataset}") if self.buffer.train_dataset is None: raise ValueError("buffer.train_dataset is required when mode is not 'both'") - self.buffer.train_dataset.algorithm_type = self.trainer.algorithm_type + self.buffer.train_dataset.algorithm_type = self.trainer.algorithm_type self.buffer.train_dataset.namespace = f"{self.monitor.project}-{self.monitor.name}" if self.buffer.sft_warmup_dataset is not None: self.buffer.sft_warmup_dataset.namespace = f"{self.monitor.project}-{self.monitor.name}" @@ -307,41 +308,63 @@ def check_and_update(self) -> None: # noqa: C901 raise ValueError("DPO does not support `both` mode") # check model path - if not os.path.isabs(self.model.model_path): - self.model.model_path = os.path.join(os.getcwd(), self.model.model_path) if not os.path.isabs(self.model.checkpoint_path): self.model.checkpoint_path = os.path.join(os.getcwd(), self.model.checkpoint_path) if not self.model.critic_model_path: self.model.critic_model_path = self.model.model_path # check synchronizer - assert self.synchronizer.sync_iteration_interval > 0 + if self.synchronizer.sync_iteration_interval is not None: + logger.warning( + f"`synchronizer.sync_iteration_interval` is deprecated, please use `synchronizer.sync_interval` instead. " + f"And `synchronizer.sync_interval` will set to {self.synchronizer.sync_iteration_interval} instead." + ) + self.synchronizer.sync_interval = self.synchronizer.sync_iteration_interval + assert self.synchronizer.sync_interval > 0 self.synchronizer.explorer_world_size = ( self.explorer.engine_num * self.explorer.tensor_parallel_size ) self.synchronizer.backend = self.explorer.backend + if ( + self.trainer.algorithm_type == AlgorithmType.DPO + and self.synchronizer.sync_method != SyncMethod.CHECKPOINT + ): + self.synchronizer.sync_method = SyncMethod.CHECKPOINT + logger.warning( + "DPO only supports checkpoint synchronization, set `synchronizer.sync_method` to `checkpoint`." 
+ ) if self.synchronizer.sync_method == SyncMethod.NCCL and self.mode != "both": raise ValueError("`nccl` synchronization is only supported in both mode.") # check eval_interval - if self.trainer.eval_interval % self.synchronizer.sync_iteration_interval != 0: + if ( + self.trainer.algorithm_type != AlgorithmType.DPO + and self.trainer.eval_interval % self.synchronizer.sync_interval != 0 + ): self.trainer.eval_interval = ( - max(self.trainer.eval_interval // self.synchronizer.sync_iteration_interval, 1) - ) * self.synchronizer.sync_iteration_interval - print( - f"Warning: eval_interval is not a multiple of sync_iteration_interval; adjusted to the nearest integer={self.trainer.eval_interval}." + max(self.trainer.eval_interval // self.synchronizer.sync_interval, 1) + ) * self.synchronizer.sync_interval + logger.warning( + f"`eval_interval` is not a multiple of `sync_interval`; adjusted to the nearest integer={self.trainer.eval_interval}." ) if self.explorer.eval_interval != self.trainer.eval_interval: self.explorer.eval_interval = self.trainer.eval_interval - print( - f"Warning: explorer.eval_interval is not equal to trainer.eval_interval; adjusted to the same value={self.trainer.eval_interval}." + logger.warning( + f"`explorer.eval_interval` is not equal to `trainer.eval_interval`; adjusted to the same value={self.trainer.eval_interval}." ) # check save_interval - if self.synchronizer.sync_method == SyncMethod.CHECKPOINT: - self.trainer.save_interval = ( - self.synchronizer.sync_iteration_interval - ) # TODO: not proper for DPO + if ( + self.trainer.algorithm_type != AlgorithmType.DPO + and self.synchronizer.sync_method == SyncMethod.CHECKPOINT + ): + if self.trainer.save_interval != self.synchronizer.sync_interval: + logger.warning( + f"When `trainer.algorithm_type != DPO` and `synchronizer.sync_method == checkpoint`, " + f"`trainer.save_interval` will be set to " + f"`synchronizer.sync_interval = {self.synchronizer.sync_interval}`." + ) + self.trainer.save_interval = self.synchronizer.sync_interval # check monitor if not self.monitor.cache_root_dir: @@ -359,6 +382,13 @@ def check_and_update(self) -> None: # noqa: C901 f"your checkpoint path: {self.model.checkpoint_path}" ) + if self.trainer.sft_warmup_iteration is not None: + logger.warning( + f"`trainer.sft_warmup_iteration` is deprecated, please use `trainer.sft_warmup_steps` instead. " + f"And `trainer.sft_warmup_steps` will be set to {self.trainer.sft_warmup_iteration} instead." 
+ ) + self.trainer.sft_warmup_steps = self.trainer.sft_warmup_iteration + # check buffer self._check_buffer() # check and update trainer diff --git a/trinity/common/constants.py b/trinity/common/constants.py index 38d9a9c162..7fcc7b4f23 100644 --- a/trinity/common/constants.py +++ b/trinity/common/constants.py @@ -2,6 +2,10 @@ """Constants.""" from enum import Enum, EnumMeta +from trinity.utils.log import get_logger + +logger = get_logger(__name__) + # names ROLLOUT_WEIGHT_SYNC_GROUP_NAME = "rollout_weight_sync" @@ -101,15 +105,15 @@ class MonitorType(CaseInsensitiveEnum): class SyncMethodEnumMeta(CaseInsensitiveEnumMeta): def __call__(cls, value, *args, **kwargs): if value == "online": - print("SyncMethod `online` is deprecated, use `nccl` instead.") + logger.warning("SyncMethod `online` is deprecated, use `nccl` instead.") value = "nccl" elif value == "offline": - print("SyncMethod `offline` is deprecated, use `checkpoint` instead.") + logger.warning("SyncMethod `offline` is deprecated, use `checkpoint` instead.") value = "checkpoint" try: return super().__call__(value, *args, **kwargs) except Exception as e: - print("Error parsing SyncMethod:", e) + logger.warning("Error parsing SyncMethod:", e) raise ValueError(f"Invalid SyncMethod: {value}") diff --git a/trinity/common/models/utils.py b/trinity/common/models/utils.py index 36f8f9843c..a8751e7240 100644 --- a/trinity/common/models/utils.py +++ b/trinity/common/models/utils.py @@ -101,22 +101,20 @@ def tokenize_and_mask_messages_default( return (tokens[0], assistant_token_mask) -def get_checkpoint_dir_with_iteration( +def get_checkpoint_dir_with_step_num( checkpoint_root_path: str, trainer_type: str = "verl", - iteration_num: Optional[int] = None, + step_num: Optional[int] = None, ) -> str: """Get the checkpoint directory from a root checkpoint directory. Args: checkpoint_root_path (str): The root checkpoint directory. trainer_type (str): The trainer type. Only support "verl" for now. - iteration_num (Optional[int], optional): The iteration number. Defaults to None. + step_num (Optional[int], optional): The step number. Defaults to None. 
""" if trainer_type == "verl": - return get_verl_checkpoint_dir( - checkpoint_path=checkpoint_root_path, iteration_num=iteration_num - ) + return get_verl_checkpoint_dir(checkpoint_path=checkpoint_root_path, step_num=step_num) else: raise NotImplementedError(f"Unsupported trainer type {trainer_type}") @@ -146,9 +144,9 @@ def merge_by_placement(tensors: List[torch.Tensor], placement: Placement): raise ValueError(f"Unsupported placement: {placement}") -def get_verl_checkpoint_dir(checkpoint_path: str, iteration_num: Optional[int] = None) -> str: +def get_verl_checkpoint_dir(checkpoint_path: str, step_num: Optional[int] = None) -> str: """Get the checkpoint directory from a Verl root checkpoint directory.""" - if iteration_num is None: + if step_num is None: # load latest checkpoint iteration_file = os.path.join(checkpoint_path, "latest_checkpointed_iteration.txt") if os.path.exists(iteration_file): @@ -162,7 +160,7 @@ def get_verl_checkpoint_dir(checkpoint_path: str, iteration_num: Optional[int] = raise FileNotFoundError(f"No iteration file found in {checkpoint_path}") else: # load specific iteration checkpoint - return os.path.join(checkpoint_path, f"global_step_{iteration_num}") + return os.path.join(checkpoint_path, f"global_step_{step_num}") # copy from verl/scripts/model_merger.py diff --git a/trinity/common/models/vllm_async_model.py b/trinity/common/models/vllm_async_model.py index fc257a6845..4134161ef5 100644 --- a/trinity/common/models/vllm_async_model.py +++ b/trinity/common/models/vllm_async_model.py @@ -80,7 +80,7 @@ def __init__( task="generate", disable_log_requests=True, gpu_memory_utilization=config.explorer.gpu_memory_utilization, - enable_chunked_prefill=config.explorer.enable_chunked_prefil, + enable_chunked_prefill=config.explorer.enable_chunked_prefill, # max_num_batched_tokens=256, # you can further set this parameter to reduce the vllm peak memory usage ) self.async_llm = vllm.AsyncLLMEngine.from_engine_args(engine_args) @@ -122,12 +122,20 @@ async def chat_async(self, messages: List[Dict], **kwargs) -> List[Experience]: self.tokenizer = await self.async_llm.get_tokenizer() if self.chat_template is None: self.chat_template = self.tokenizer.get_chat_template() - prompt = self.tokenizer.apply_chat_template( - messages, - chat_template=self.chat_template, - tokenize=False, - add_generation_prompt=True, - ) + if messages[-1]["role"] == "assistant": + prompt = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + continue_final_message=True, + chat_template=self.chat_template, + ) + else: + prompt = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + chat_template=self.chat_template, + ) return await self.generate_async(prompt=prompt, **kwargs) async def generate_async(self, prompt: str, **kwargs) -> List[Experience]: diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py index b993b82586..0964484e00 100644 --- a/trinity/common/models/vllm_model.py +++ b/trinity/common/models/vllm_model.py @@ -65,7 +65,7 @@ def __init__(self, config: Config, **kwargs): dtype=config.explorer.dtype, trust_remote_code=True, gpu_memory_utilization=config.explorer.gpu_memory_utilization, - enable_chunked_prefill=config.explorer.enable_chunked_prefil, + enable_chunked_prefill=config.explorer.enable_chunked_prefill, # max_num_batched_tokens=256, **kwargs, ) @@ -220,12 +220,20 @@ def chat(self, messages: List[dict], **kwargs) -> List[Experience]: List[Experience]: A list of experiences containing the 
response text. """ # TODO: support tools and other fields - prompt = self.tokenizer.apply_chat_template( - messages, - tokenize=False, - add_generation_prompt=True, - chat_template=self.chat_template, - ) + if messages[-1]["role"] == "assistant": + prompt = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + continue_final_message=True, + chat_template=self.chat_template, + ) + else: + prompt = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + chat_template=self.chat_template, + ) return self.generate([prompt], **kwargs) def logprobs(self, token_ids: List[int]) -> torch.Tensor: diff --git a/trinity/common/task.py b/trinity/common/task.py index 5f5309565e..781e755739 100644 --- a/trinity/common/task.py +++ b/trinity/common/task.py @@ -26,7 +26,7 @@ class Task: task_desc: str workflow: Type[Workflow] - reward_fn: Optional[RewardFn] = None + reward_fn: Optional[Type[RewardFn]] = None truth: Optional[str] = None raw: Optional[dict] = None # The raw data sample task_type: Optional[TaskType] = None @@ -62,7 +62,7 @@ def task_generator( start_index: int, config: DataConfig, default_workflow: Optional[Type[Workflow]], - default_reward_fn: Optional[RewardFn], + default_reward_fn: Optional[Type[RewardFn]], task_type: Optional[TaskType], ) -> Iterator[Task]: """Get a generator of tasks from the dataset.""" @@ -116,7 +116,7 @@ class TaskSet: dataset: Any # the source huggingface dataset config: DataConfig - reward_fn: Optional[RewardFn] = None + reward_fn: Optional[Type[RewardFn]] = None workflow: Optional[Type[Workflow]] = None task_type: Optional[TaskType] = None default_index: int = 0 @@ -151,12 +151,11 @@ def load( dataset_len = len(dataset) default_workflow_cls = WORKFLOWS.get(config.default_workflow_type) default_reward_fn_cls = REWARD_FUNCTIONS.get(config.default_reward_fn_type) - default_reward_instance = default_reward_fn_cls() if default_reward_fn_cls else None return cls( dataset=dataset, config=config, workflow=default_workflow_cls, - reward_fn=default_reward_instance, + reward_fn=default_reward_fn_cls, task_type=task_type, default_index=latest_task_index % dataset_len, default_epoch=latest_task_index // dataset_len, diff --git a/trinity/common/verl_config.py b/trinity/common/verl_config.py index c46fefb849..ac8bf6ce21 100644 --- a/trinity/common/verl_config.py +++ b/trinity/common/verl_config.py @@ -5,6 +5,10 @@ from trinity.common.config import BufferConfig, Config, SynchronizerConfig from trinity.common.constants import AlgorithmType +from trinity.trainer.verl.ray_trainer import AdvantageEstimator +from trinity.utils.log import get_logger + +logger = get_logger(__name__) @dataclass @@ -245,7 +249,7 @@ class Trainer: training_rollout_mode: str = "parallel" enable_exp_buffer: bool = True sync_freq: int = 0 - sft_warmup_iteration: int = 0 + sft_warmup_steps: int = 0 max_actor_ckpt_to_keep: Optional[int] = None max_critic_ckpt_to_keep: Optional[int] = None @@ -278,7 +282,7 @@ def synchronize_config(self, config: Config) -> None: else: # for multi-node scenarios, some nodes for rollout, others for training self.trainer.n_gpus_per_node = config.cluster.gpu_per_node - self.trainer.sync_freq = config.synchronizer.sync_iteration_interval + self.trainer.sync_freq = config.synchronizer.sync_interval self.trainer.save_freq = config.trainer.save_interval self.synchronizer = config.synchronizer self.actor_rollout_ref.synchronizer = config.synchronizer @@ -289,20 +293,31 @@ def synchronize_config(self, config: Config) -> None: 
f"batch_size ({config.data.batch_size}) must be divisible by ({world_size})" ) # TODO: use dynamic read_batch_size to support multi-round scenarios - # Get the experiences of one explore iteration + # Get the experiences of one explore step self.buffer.pad_token_id = config.buffer.pad_token_id self.trainer.project_name = config.monitor.project self.trainer.experiment_name = config.monitor.name self.data.train_batch_size = config.data.batch_size self.trainer.default_local_dir = config.model.checkpoint_path - self.trainer.sft_warmup_iteration = config.trainer.sft_warmup_iteration + self.trainer.sft_warmup_steps = config.trainer.sft_warmup_steps self.actor_rollout_ref.actor.ppo_mini_batch_size = config.data.batch_size self.actor_rollout_ref.rollout.temperature = config.explorer.temperature self.actor_rollout_ref.rollout.n = config.explorer.repeat_times + self.critic.ppo_mini_batch_size = config.data.batch_size + self.actor_rollout_ref.actor.algorithm_type = config.trainer.algorithm_type + if config.trainer.algorithm_type == AlgorithmType.PPO: + logger.info("Using GAE `adv_estimator` for PPO") + self.algorithm.adv_estimator = AdvantageEstimator.GAE.value + elif config.trainer.algorithm_type == AlgorithmType.GRPO: + logger.info("Using GRPO `adv_estimator` for GRPO") + self.algorithm.adv_estimator = AdvantageEstimator.GRPO.value if self.actor_rollout_ref.actor.algorithm_type.is_dpo(): # for DPO - print("Warning: DPO micro batch size is doubled for computing loss.") + if not self.actor_rollout_ref.actor.use_kl_loss: + self.actor_rollout_ref.actor.use_kl_loss = True + logger.warning("DPO must use KL loss.") + logger.warning("DPO micro batch size is doubled for computing loss.") self.actor_rollout_ref.actor.ppo_mini_batch_size *= 2 self.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu *= 2 # type: ignore self.actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu *= 2 # type: ignore diff --git a/trinity/common/workflows/envs/alfworld/alfworld_workflow.py b/trinity/common/workflows/envs/alfworld/alfworld_workflow.py index 170a1971e6..a7e45e4a7f 100644 --- a/trinity/common/workflows/envs/alfworld/alfworld_workflow.py +++ b/trinity/common/workflows/envs/alfworld/alfworld_workflow.py @@ -101,7 +101,7 @@ def __init__(self, model: ModelWrapper, **kwargs): self.system_prompt = kwargs.get("system_prompt", None) # Unuse here self.task_desc: str = kwargs.get("task_desc") self.truth = kwargs.get("truth") # Unuse here - self.reward_fn = kwargs.get("reward_fn") # Unuse here + self.reward_fn = None # Unuse here self.repeat_times = kwargs.get("repeat_times", 1) self.max_env_steps = 30 diff --git a/trinity/common/workflows/envs/sciworld/sciworld_workflow.py b/trinity/common/workflows/envs/sciworld/sciworld_workflow.py index 4892abb97f..06a0748f4c 100644 --- a/trinity/common/workflows/envs/sciworld/sciworld_workflow.py +++ b/trinity/common/workflows/envs/sciworld/sciworld_workflow.py @@ -64,7 +64,7 @@ def __init__(self, model: ModelWrapper, **kwargs): self.system_prompt = kwargs.get("system_prompt", None) # Unuse here self.task_desc: str = kwargs.get("task_desc") self.truth = kwargs.get("truth") # Unuse here - self.reward_fn = kwargs.get("reward_fn") # Unuse here + self.reward_fn = None # Unuse here self.repeat_times = kwargs.get("repeat_times", 1) self.max_env_steps = 30 # should be less than 100 diff --git a/trinity/common/workflows/envs/webshop/webshop_workflow.py b/trinity/common/workflows/envs/webshop/webshop_workflow.py index 5773f8e6e8..741fcd4b05 100644 --- 
a/trinity/common/workflows/envs/webshop/webshop_workflow.py +++ b/trinity/common/workflows/envs/webshop/webshop_workflow.py @@ -186,7 +186,7 @@ def __init__(self, model: ModelWrapper, **kwargs): self.system_prompt = kwargs.get("system_prompt", None) # Unuse here self.task_desc: str = kwargs.get("task_desc", "0") self.truth = kwargs.get("truth") # Unuse here - self.reward_fn = kwargs.get("reward_fn") # Unuse here + self.reward_fn = None # Unuse here self.repeat_times = kwargs.get("repeat_times", 1) self.max_env_steps = 15 diff --git a/trinity/common/workflows/workflow.py b/trinity/common/workflows/workflow.py index 1e82b2dfd3..40fc7aa3ce 100644 --- a/trinity/common/workflows/workflow.py +++ b/trinity/common/workflows/workflow.py @@ -8,7 +8,7 @@ from trinity.common.experience import Experience from trinity.common.models.model import ModelWrapper -from trinity.common.rewards.reward_fn import MathRewardFn +from trinity.common.rewards.reward_fn import MathRewardFn, RewardFn from trinity.utils.log import get_logger from trinity.utils.registry import Registry @@ -85,22 +85,27 @@ def __init__( ): super().__init__(model) self.system_prompt = kwargs.get("system_prompt", None) + self.reply_prefix = kwargs.get("reply_prefix", None) # TODO: add reply_prefix self.task_desc = kwargs.get("task_desc") self.truth = kwargs.get("truth") self.reward_fn = kwargs.get("reward_fn") + if isinstance(self.reward_fn, type) and issubclass(self.reward_fn, RewardFn): + self.reward_fn = self.reward_fn() + else: + raise ValueError("`reward_fn` must be a subclass of `RewardFn`") # Rollout n times self.repeat_times = kwargs.get("repeat_times", 1) self.is_eval = kwargs.get("is_eval", False) def run(self) -> List[Experience]: # TODO: Optimize the generate function + messages = [] if self.system_prompt: - messages = [ - {"role": "system", "content": self.system_prompt}, - {"role": "user", "content": self.task_desc}, - ] - else: - messages = [{"role": "user", "content": self.task_desc}] + messages.append({"role": "system", "content": self.system_prompt}) + messages.append({"role": "user", "content": self.task_desc}) + if self.reply_prefix: + messages.append({"role": "assistant", "content": self.reply_prefix}) + logger.debug("start chat") n = 1 if self.is_eval else self.repeat_times responses = self.model.chat(messages, n=n) @@ -132,7 +137,7 @@ def __init__( **kwargs, ): if kwargs.get("reward_fn", None) is None: - kwargs["reward_fn"] = MathRewardFn() + kwargs["reward_fn"] = MathRewardFn kwargs[ "system_prompt" ] = """A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. 
The reasoning process and answer are enclosed within and tags, respectively, i.e., diff --git a/trinity/data/core/dataset.py b/trinity/data/core/dataset.py index aad081e874..de6fa7281d 100644 --- a/trinity/data/core/dataset.py +++ b/trinity/data/core/dataset.py @@ -68,12 +68,11 @@ def format( def to_taskset(self, **kwargs) -> TaskSet: default_workflow_cls = WORKFLOWS.get(self.config.default_workflow_type) default_reward_fn_cls = REWARD_FUNCTIONS.get(self.config.default_reward_fn_type) - default_reward_instance = default_reward_fn_cls() if default_reward_fn_cls else None return TaskSet( dataset=self.data, config=self.config, workflow=default_workflow_cls, - reward_fn=default_reward_instance, + reward_fn=default_reward_fn_cls, ) def to_parquet(self, path: str): diff --git a/trinity/data/readme.md b/trinity/data/readme.md index e331a5726f..db1b3f443b 100644 --- a/trinity/data/readme.md +++ b/trinity/data/readme.md @@ -53,7 +53,7 @@ dataset.format([ # convert to a task set with global reward function and workflow task_set = dataset.to_taskset( - reward_fn=AccuracyReward(), + reward_fn=AccuracyReward, workflow=MathWorkflow, ) diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index a52dc2b14b..49a0b60e6f 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -17,7 +17,7 @@ ) from trinity.common.models import create_rollout_models from trinity.common.models.utils import ( - get_checkpoint_dir_with_iteration, + get_checkpoint_dir_with_step_num, load_state_dict, ) from trinity.common.task import TaskSet @@ -35,7 +35,7 @@ def __init__(self, config: Config): self.logger = get_logger(__name__) self.cache = CacheManager(config) explorer_meta = self.cache.load_explorer() - self.iteration = explorer_meta.get("latest_iteration", 0) + self.step_num = explorer_meta.get("latest_iteration", 0) self.config = config self.models = create_rollout_models(config) self.experience_buffer = get_buffer_writer( @@ -60,9 +60,7 @@ def __init__(self, config: Config): self.max_pending_task_num = self.config.explorer.runner_num self.max_waiting_steps = max(1, int(self.config.explorer.max_waiting_steps)) self.batch_size = config.data.batch_size - self.update_interval = ( - self.config.synchronizer.sync_iteration_interval * self.config.data.batch_size - ) + self.update_interval = self.config.synchronizer.sync_interval * self.config.data.batch_size self.use_checkpoint_weights_update = ( self.config.synchronizer.sync_method == SyncMethod.CHECKPOINT ) @@ -129,13 +127,13 @@ def _update_model_weight(self, state_dict: dict) -> None: ray.get([model.sync_model.remote(update_weight_args_list) for model in self.models]) self.state_dict.clear() - def _checkpoint_weights_update(self, iteration_num: Optional[int] = None) -> None: + def _checkpoint_weights_update(self, step_num: Optional[int] = None) -> None: # TODO: support more checkpoint types try: - checkpoint_dir = get_checkpoint_dir_with_iteration( + checkpoint_dir = get_checkpoint_dir_with_step_num( checkpoint_root_path=self.config.model.checkpoint_path, trainer_type=self.config.trainer.trainer_type, - iteration_num=iteration_num, + step_num=step_num, ) if checkpoint_dir == self.old_checkpoint: return @@ -162,7 +160,7 @@ def get_weight(self, name: str) -> torch.Tensor: def explore(self) -> None: """Explore the entire dataset.""" while True: - explore_status, explore_iter = self.explore_step() + explore_status, explore_iter = self.explore_one_period() if not explore_status: break self.sync_weight() @@ -171,33 +169,31 @@ def explore(self) 
-> None: self.logger.info("Evaluation finished.") self.logger.info("Explorer finished.") - def explore_step(self) -> Tuple[bool, int]: - """Explore for one step. + def explore_one_period(self) -> Tuple[bool, int]: + """Explore for one period. Different from `explore()` which consumes all tasks in the task set, - `explore_step()` only consume `sync_iteration_interval * batch_size` + `explore_one_period()` only consume `sync_interval * batch_size` number of tasks. - explore_status: + Returns: explore_status: whether there are more tasks to explore. - explore_iter_num: the number of explore iterations + explore_step_num: the number of explore steps """ if self.task_iter is None: self.task_iter = iter(self.taskset) - task_num_per_step = ( - self.config.synchronizer.sync_iteration_interval * self.config.data.batch_size - ) + task_num_per_period = self.config.synchronizer.sync_interval * self.config.data.batch_size st = time.time() all_metrics = defaultdict(list) # submit tasks of this step try: - tasks = [next(self.task_iter) for _ in range(task_num_per_step)] # type: ignore + tasks = [next(self.task_iter) for _ in range(task_num_per_period)] # type: ignore self.runner_pool.run_tasks(tasks) except StopIteration: self.experience_buffer.finish() self.logger.warning("No more tasks in the task set. Stop exploring.") - return False, self.iteration + return False, self.step_num # wait for all tasks of this step to finish while self.runner_pool.has_next(): @@ -212,7 +208,7 @@ def explore_step(self) -> Tuple[bool, int]: self.runner_pool.run_tasks(next(self.task_iter)) # type: ignore except StopIteration: self.logger.warning("No more tasks in the task set. Stop exploring.") - return False, self.iteration + return False, self.step_num else: for metric_name, metric_value in status.metric.items(): all_metrics[metric_name].append(metric_value) @@ -220,17 +216,17 @@ def explore_step(self) -> Tuple[bool, int]: # calculate metrics log_metrics = self.monitor.calculate_metrics(all_metrics, prefix="rollout") # type: ignore log_metrics["rollout/step_time"] = time.time() - st - self.iteration += self.config.synchronizer.sync_iteration_interval - self.monitor.log(log_metrics, step=self.iteration) + self.step_num += self.config.synchronizer.sync_interval + self.monitor.log(log_metrics, step=self.step_num) # save explore checkpoint self.cache.save_explorer( - current_iteration=self.iteration, + current_step=self.step_num, current_task_index=self.taskset.index, ) - self.logger.info(f"Explore iteration {self.iteration} finished.") - return True, self.iteration + self.logger.info(f"Explore step {self.step_num} finished.") + return True, self.step_num def eval(self) -> bool: """Evaluation on all evaluation data samples.""" @@ -258,7 +254,7 @@ def eval(self) -> bool: log_metrics = self.monitor.calculate_metrics(all_metrics, prefix="eval") # type: ignore log_metrics["eval/total_time"] = time.time() - st - self.monitor.log(log_metrics, step=self.iteration) # type: ignore + self.monitor.log(log_metrics, step=self.step_num) # type: ignore return True def sync_weight(self) -> None: diff --git a/trinity/manager/config_manager.py b/trinity/manager/config_manager.py index e1652c60bf..86232bbc50 100644 --- a/trinity/manager/config_manager.py +++ b/trinity/manager/config_manager.py @@ -41,6 +41,7 @@ def __init__(self): def _init_default_config(self): self.default_config = { "_init_config_manager": True, + "mode": "both", "project": "Trinity-RFT", "exp_name": "qwen2.5-1.5B", "monitor_type": MonitorType.WANDB.value, @@ -71,7 +72,7 @@ 
def _init_default_config(self): "_not_dpo_storage_type": StorageType.QUEUE.value, "storage_type": StorageType.QUEUE.value, "train_dataset_path": "", - "max_retry_times": 3, + "buffer_max_retry_times": 3, "max_retry_interval": 1, "dpo_dataset_train_split": "train", "dpo_dataset_prompt_type": PromptType.MESSAGES.value, @@ -87,31 +88,37 @@ def _init_default_config(self): # Explorer and Sync Configs "engine_type": "vllm_async", "engine_num": 2, - "tensor_parallel_size": 1, + "runner_num": 32, "_grouped_adv_repeat_times": 2, "_not_grouped_adv_repeat_times": 1, "repeat_times": 1, - "_not_dpo_sync_method": SyncMethod.NCCL.value, - "sync_method": SyncMethod.NCCL.value, - "sync_iteration_interval": 10, - "sync_timeout": 1200, - "runner_num": 32, - "max_pending_requests": 32, - "max_waiting_steps": 4, + "eval_interval": 1000, + "tensor_parallel_size": 1, + "enable_prefix_caching": False, + "enforce_eager": True, "dtype": "bfloat16", - "backend": "nccl", "temperature": 1.0, "top_p": 1.0, "top_k": -1, "seed": 42, "logprobs": 0, - "enable_prefix_caching": False, - "enforce_eager": True, + "backend": "nccl", + "use_ray": False, + "gpu_memory_utilization": 0.9, + "enable_chunked_prefill": False, + "max_pending_requests": 32, + "max_waiting_steps": 4, + "max_timeout": 900, + "explorer_max_retry_times": 2, + # Synchronizer Configs + "_not_dpo_sync_method": SyncMethod.NCCL.value, + "sync_method": SyncMethod.NCCL.value, + "sync_interval": 10, + "sync_timeout": 1200, # Trainer Configs "trainer_type": "verl", "algorithm_type": AlgorithmType.PPO.value, - "sft_warmup_iteration": 0, - "eval_interval": 1000, + "sft_warmup_steps": 0, "_nccl_save_interval": 100, "save_interval": 100, # veRL Trainer Configs @@ -121,6 +128,7 @@ def _init_default_config(self): "remove_padding", "dynamic_bsz", ], + "ppo_epochs": 1, "training_strategy": "fsdp", "param_offload": False, "optimizer_offload": False, @@ -155,6 +163,7 @@ def _init_default_config(self): "actor_grad_clip": 1.0, "actor_clip_ratio": 0.2, "actor_entropy_coeff": 0.001, + "_not_dpo_actor_use_kl_loss": True, "actor_use_kl_loss": True, "actor_kl_loss_coef": 0.001, "actor_kl_loss_type": "low_var_kl", @@ -230,10 +239,13 @@ def _set_total_gpu_num(self): self._set_trainer_gpu_num() def _set_trainer_gpu_num(self): - st.session_state["trainer_gpu_num"] = ( - st.session_state["total_gpu_num"] - - st.session_state["engine_num"] * st.session_state["tensor_parallel_size"] - ) + if st.session_state["mode"] == "both": + st.session_state["trainer_gpu_num"] = ( + st.session_state["total_gpu_num"] + - st.session_state["engine_num"] * st.session_state["tensor_parallel_size"] + ) + else: # model == train + st.session_state["trainer_gpu_num"] = st.session_state["total_gpu_num"] def _set_max_prompt_tokens(self): st.number_input("Max Prompt Tokens", key="max_prompt_tokens", min_value=1) @@ -246,10 +258,14 @@ def _set_total_epochs(self): @property def _str_for_train_batch_size(self): + trainer_gpu_num_str = ( + "`gpu_per_node * node_num - engine_num * tensor_parallel_size`" + if st.session_state["mode"] == "both" + else "`gpu_per_node * node_num`" + ) return ( f"Please ensure that `train_batch_size` can be divided by " - f"`gpu_per_node * node_num - engine_num * tensor_parallel_size` " - f"= {st.session_state['trainer_gpu_num']}" + f"{trainer_gpu_num_str} = {st.session_state['trainer_gpu_num']}." 
) def _set_train_batch_size(self): @@ -286,16 +302,22 @@ def _set_dataset_path(self): def _set_dataset_args(self): if st.session_state["dataset_path"] and "://" not in st.session_state["dataset_path"]: subset_name_col, train_split_col, eval_split_col = st.columns(3) - subset_name_col.text_input("Subset Name", key="subset_name") - train_split_col.text_input("Train Split", key="train_split") - eval_split_col.text_input("Eval Split", key="eval_split") + subset_name_col.text_input( + "Subset Name :orange-badge[(Needs review)]", key="subset_name" + ) + train_split_col.text_input( + "Train Split :orange-badge[(Needs review)]", key="train_split" + ) + eval_split_col.text_input("Eval Split :orange-badge[(Needs review)]", key="eval_split") prompt_key_col, response_key_col = st.columns(2) - prompt_key_col.text_input("Prompt Key", key="prompt_key") - response_key_col.text_input("Response Key", key="response_key") + prompt_key_col.text_input("Prompt Key :orange-badge[(Needs review)]", key="prompt_key") + response_key_col.text_input( + "Response Key :orange-badge[(Needs review)]", key="response_key" + ) def _set_default_workflow_type(self): st.selectbox( - "Default Workflow Type", + "Default Workflow Type :orange-badge[(Needs review)]", WORKFLOWS.modules.keys(), key="default_workflow_type", help=r"""`simple_workflow`: call 'model.chat()' to get responses. @@ -308,7 +330,7 @@ def _set_default_workflow_type(self): def _set_default_reward_fn_type(self): st.selectbox( - "Default Reward Fn Type", + "Default Reward Fn Type :orange-badge[(Needs review)]", REWARD_FUNCTIONS.modules.keys(), key="default_reward_fn_type", help=r"""`accuracy_reward`: check the accuracy for math problems. @@ -354,8 +376,8 @@ def _set_train_dataset_path(self): # TODO self.unfinished_fields.add("train_dataset_path") st.warning("Please input train dataset path.") - def _set_max_retry_times(self): - st.number_input("Max Retry Times", key="max_retry_times", min_value=1) + def _set_buffer_max_retry_times(self): + st.number_input("Max Retry Times", key="buffer_max_retry_times", min_value=1) def _set_max_retry_interval(self): st.number_input("Max Retry Interval", key="max_retry_interval", min_value=1) @@ -363,10 +385,10 @@ def _set_max_retry_interval(self): def _set_dpo_dataset_kwargs(self): dpo_dataset_train_split_col, dpo_dataset_prompt_type_col = st.columns(2) dpo_dataset_train_split_col.text_input( - "DPO Dataset Train Split", key="dpo_dataset_train_split" + "DPO Dataset Train Split :orange-badge[(Needs review)]", key="dpo_dataset_train_split" ) dpo_dataset_prompt_type_col.selectbox( - "DPO Dataset Prompt Type", + "DPO Dataset Prompt Type :orange-badge[(Needs review)]", [prompt_type.value for prompt_type in PromptType], key="dpo_dataset_prompt_type", ) @@ -377,22 +399,21 @@ def _set_dpo_dataset_kwargs(self): dpo_dataset_rejected_key_col, ) = st.columns(3) dpo_dataset_prompt_key_col.text_input( - "DPO Dataset Prompt Key", key="dpo_dataset_prompt_key" + "DPO Dataset Prompt Key :orange-badge[(Needs review)]", key="dpo_dataset_prompt_key" ) dpo_dataset_chosen_key_col.text_input( - "DPO Dataset Chosen Key", key="dpo_dataset_chosen_key" + "DPO Dataset Chosen Key :orange-badge[(Needs review)]", key="dpo_dataset_chosen_key" ) dpo_dataset_rejected_key_col.text_input( - "DPO Dataset Rejected Key", key="dpo_dataset_rejected_key" + "DPO Dataset Rejected Key :orange-badge[(Needs review)]", + key="dpo_dataset_rejected_key", ) def _check_sft_warmup_dataset_path(self): - if st.session_state["sft_warmup_iteration"]: + if 
st.session_state["sft_warmup_steps"]: if not st.session_state["sft_warmup_dataset_path"].strip(): self.unfinished_fields.add("sft_warmup_dataset_path") - st.warning( - "Please input SFT warmup dataset path when `sft_warmup_iteration` is not 0" - ) + st.warning("Please input SFT warmup dataset path when `sft_warmup_steps` is not 0") def _set_sft_warmup_dataset_path(self): st.text_input("SFT Warmup Dataset Path", key="sft_warmup_dataset_path") @@ -407,9 +428,12 @@ def _set_sft_warmup_dataset_args(self): sft_warmup_train_split_col, sft_warmup_prompt_type_col, ) = st.columns(2) - sft_warmup_train_split_col.text_input("SFT Train Split", key="sft_warmup_train_split") + sft_warmup_train_split_col.text_input( + "SFT Dataset Train Split :orange-badge[(Needs review)]", + key="sft_warmup_train_split", + ) sft_warmup_prompt_type_col.selectbox( - "SFT Prompt Type", + "SFT Dataset Prompt Type :orange-badge[(Needs review)]", [prompt_type.value for prompt_type in PromptType], key="sft_warmup_prompt_type", ) @@ -419,11 +443,15 @@ def _set_sft_warmup_dataset_args(self): sft_warmup_response_key_col, ) = st.columns(3) sft_warmup_messages_key_col.text_input( - "SFT Messages Key", key="sft_warmup_messages_key" + "SFT Dataset Messages Key :orange-badge[(Needs review)]", + key="sft_warmup_messages_key", + ) + sft_warmup_prompt_key_col.text_input( + "SFT Dataset Prompt Key :orange-badge[(Needs review)]", key="sft_warmup_prompt_key" ) - sft_warmup_prompt_key_col.text_input("SFT Prompt Key", key="sft_warmup_prompt_key") sft_warmup_response_key_col.text_input( - "SFT Response Key", key="sft_warmup_response_key" + "SFT Dataset Response Key :orange-badge[(Needs review)]", + key="sft_warmup_response_key", ) def _set_engine_type(self): @@ -531,19 +559,19 @@ def on_change(): "Sync Method", [sync_method.value for sync_method in SyncMethod], key="sync_method", - help="""`nccl`: the explorer and trainer sync model weights once every `sync_iteration_interval` steps. + help="""`nccl`: the explorer and trainer sync model weights once every `sync_interval` steps. 
-`checkpoint`: the trainer saves the model checkpoint, and the explorer loads it at `sync_iteration_interval`.""", +`checkpoint`: the trainer saves the model checkpoint, and the explorer loads it at `sync_interval`.""", disabled=disabled, on_change=on_change, ) - def _set_sync_iteration_interval(self): + def _set_sync_interval(self): st.number_input( - "Sync Iteration Interval", - key="sync_iteration_interval", + "Sync Interval", + key="sync_interval", min_value=1, - help="""The iteration interval at which the `explorer` and `trainer` synchronize model weight.""", + help="""The step interval at which the `explorer` and `trainer` synchronize model weight.""", ) def _set_sync_timeout(self): @@ -591,15 +619,49 @@ def _set_logprobs(self): st.number_input("Logprobs", key="logprobs", min_value=0, max_value=20) def _set_enable_prefix_caching(self): - st.checkbox("Enable Prefix Caching", key="enable_prefix_caching") + st.checkbox("Prefix Caching", key="enable_prefix_caching") def _set_enforce_eager(self): st.checkbox("Enforce Eager", key="enforce_eager") + def _set_use_ray(self): + st.checkbox("Use Ray", key="use_ray") + + def _set_gpu_memory_utilization(self): + st.number_input( + "GPU Memory Utilization", key="gpu_memory_utilization", min_value=0.0, max_value=1.0 + ) + + def _set_enable_chunked_prefill(self): + st.checkbox("Chunked Prefill", key="enable_chunked_prefill") + + def _set_max_timeout(self): + st.number_input("Max Timeout", key="max_timeout", min_value=0) + + def _set_explorer_max_retry_times(self): + st.number_input("Explorer Max Retry Times", key="explorer_max_retry_times", min_value=0) + def _set_trainer_type(self): st.selectbox("Trainer Type", ["verl"], key="trainer_type") def _set_algorithm_type(self): + def on_change(): + if st.session_state["algorithm_type"] == AlgorithmType.PPO.value: + st.session_state["mode"] = "both" + st.session_state["adv_estimator"] = AdvantageEstimator.GAE.value + elif st.session_state["algorithm_type"] == AlgorithmType.GRPO.value: + st.session_state["mode"] = "both" + st.session_state["adv_estimator"] = AdvantageEstimator.GRPO.value + elif st.session_state["algorithm_type"] == AlgorithmType.DPO.value: + st.session_state["mode"] = "train" + st.session_state["adv_estimator"] = AdvantageEstimator.GRPO.value + elif st.session_state["algorithm_type"] == AlgorithmType.OPMD.value: + st.session_state["mode"] = "both" + st.session_state["adv_estimator"] = AdvantageEstimator.GRPO.value + else: # TODO: add more algorithms + pass + self._set_trainer_gpu_num() + st.selectbox( "Algorithm Type", [ @@ -609,11 +671,11 @@ def _set_algorithm_type(self): AlgorithmType.OPMD.value, ], key="algorithm_type", - on_change=self._set_adv_estimator, + on_change=on_change, ) - def _set_sft_warmup_iteration(self): - st.number_input("SFT Warmup Iteration", key="sft_warmup_iteration", min_value=0) + def _set_sft_warmup_steps(self): + st.number_input("SFT Warmup Steps", key="sft_warmup_steps", min_value=0) def _set_eval_interval(self): st.number_input("Eval Interval", key="eval_interval", min_value=1) @@ -631,26 +693,35 @@ def _set_training_args(self): ) def _set_save_interval(self): - if st.session_state["sync_method"] == SyncMethod.NCCL.value: + if ( + st.session_state["algorithm_type"] == AlgorithmType.DPO.value + or st.session_state["sync_method"] == SyncMethod.NCCL.value + ): st.session_state["save_interval"] = st.session_state["_nccl_save_interval"] freeze_save_interval = False else: - st.session_state["save_interval"] = st.session_state["sync_iteration_interval"] + 
st.session_state["save_interval"] = st.session_state["sync_interval"] freeze_save_interval = True def on_change(): - if st.session_state["sync_method"] == SyncMethod.NCCL.value: + if ( + st.session_state["algorithm_type"] == AlgorithmType.DPO.value + or st.session_state["sync_method"] == SyncMethod.NCCL.value + ): st.session_state["_nccl_save_interval"] = st.session_state["save_interval"] st.number_input( "Save Interval", key="save_interval", min_value=1, - help="Set to `sync_iteration_interval` when `sync_method` is `checkpoint`", + help="Set to `sync_interval` when `algorithm_type != DPO && sync_method == checkpoint`", disabled=freeze_save_interval, on_change=on_change, ) + def _set_ppo_epochs(self): + st.number_input("PPO Epochs", key="ppo_epochs", min_value=1) + def _set_training_strategy(self): st.selectbox( "Training Strategy", @@ -679,7 +750,7 @@ def _set_resume_from_path(self): st.warning("Please input a valid resume path when `resume_mode == resume_path`") def _set_critic_warmup(self): - st.number_input("Critic Warmup Iteration", key="critic_warmup", min_value=0) + st.number_input("Critic Warmup Steps", key="critic_warmup", min_value=0) def _set_total_training_steps(self): st.number_input("Total Training Steps", key="total_training_steps", min_value=1) @@ -700,22 +771,10 @@ def _set_max_critic_ckpt_to_keep(self): st.number_input("Max Critic Checkpoint to Keep", key="max_critic_ckpt_to_keep", min_value=1) def _set_gamma(self): - st.number_input("Gamma", key="gamma") + st.number_input(r"Gamma :blue-badge[$\gamma$]", key="gamma") def _set_lam(self): - st.number_input("Lambda", key="lam") - - def _set_adv_estimator(self): - if st.session_state["algorithm_type"] == AlgorithmType.PPO.value: - st.session_state["adv_estimator"] = AdvantageEstimator.GAE.value - elif st.session_state["algorithm_type"] == AlgorithmType.GRPO.value: - st.session_state["adv_estimator"] = AdvantageEstimator.GRPO.value - elif st.session_state["algorithm_type"] == AlgorithmType.DPO.value: - st.session_state["adv_estimator"] = AdvantageEstimator.GRPO.value - elif st.session_state["algorithm_type"] == AlgorithmType.OPMD.value: - st.session_state["adv_estimator"] = AdvantageEstimator.GRPO.value - else: # TODO: add more algorithms - pass + st.number_input(r"Lambda :blue-badge[$\lambda$]", key="lam") def _set_norm_adv_by_std_in_grpo(self): st.checkbox("Norm Adv by Std in GRPO", key="norm_adv_by_std_in_grpo") @@ -744,7 +803,7 @@ def _set_actor_ppo_micro_batch_size_per_gpu(self): st.session_state["_train_batch_size_per_gpu"], ) st.number_input( - "Micro Batch Size Per GPU for Actor", + "Micro Batch Size Per GPU :blue-badge[(Actor)]", key="actor_ppo_micro_batch_size_per_gpu", min_value=1, max_value=st.session_state["_train_batch_size_per_gpu"], @@ -756,7 +815,7 @@ def _set_ref_log_prob_micro_batch_size_per_gpu(self): st.session_state["_train_batch_size_per_gpu"], ) st.number_input( - "Micro Batch Size Per GPU for Ref", + "Micro Batch Size Per GPU :blue-badge[(Ref)]", key="ref_log_prob_micro_batch_size_per_gpu", min_value=1, max_value=st.session_state["_train_batch_size_per_gpu"], @@ -772,7 +831,7 @@ def _set_actor_ulysses_sequence_parallel_size(self): def _set_actor_lr(self): st.number_input( - "Learning Rate for Actor", + "Learning Rate :blue-badge[(Actor)]", key="actor_lr", min_value=1e-7, max_value=1e-3, @@ -781,24 +840,35 @@ def _set_actor_lr(self): def _set_actor_warmup_style(self): st.selectbox( - "LR Warmup Style for Actor", + "LR Warmup Style :blue-badge[(Actor)]", ["constant", "cosine"], key="actor_warmup_style", 
) def _set_actor_lr_warmup_steps_ratio(self): st.number_input( - "LR Warmup Steps Ratio for Actor", + "LR Warmup Steps Ratio :blue-badge[(Actor)]", key="actor_lr_warmup_steps_ratio", min_value=0.0, max_value=1.0, ) def _set_actor_grad_clip(self): - st.number_input("Grad Clip", key="actor_grad_clip", min_value=0.0, max_value=1.0) + st.number_input( + "Grad Clip :blue-badge[(Actor)]", + key="actor_grad_clip", + min_value=0.0, + max_value=1.0, + help="Clipping by Norm", + ) def _set_actor_clip_ratio(self): - st.number_input("Clip Ratio", key="actor_clip_ratio", min_value=0.0, max_value=1.0) + st.number_input( + r"Clip Ratio :blue-badge[$\epsilon$]", + key="actor_clip_ratio", + min_value=0.0, + max_value=1.0, + ) def _set_actor_entropy_coeff(self): st.number_input( @@ -810,11 +880,21 @@ def _set_actor_entropy_coeff(self): ) def _set_actor_use_kl_loss(self): - st.checkbox("Use KL Loss", key="actor_use_kl_loss") + if st.session_state["algorithm_type"] == AlgorithmType.DPO.value: + st.session_state["actor_use_kl_loss"] = True + else: + st.session_state["actor_use_kl_loss"] = st.session_state["_not_dpo_actor_use_kl_loss"] + + def on_change(): + st.session_state["_not_dpo_actor_use_kl_loss"] = st.session_state[ + "actor_use_kl_loss" + ] + + st.checkbox("Use KL Loss", key="actor_use_kl_loss", on_change=on_change) def _set_actor_kl_loss_coef(self): st.number_input( - "KL Loss Coef", + r"KL Loss Coef :blue-badge[$\beta$]", key="actor_kl_loss_coef", min_value=0.0, max_value=1.0, @@ -859,7 +939,7 @@ def _set_critic_ppo_micro_batch_size_per_gpu(self): st.session_state["_train_batch_size_per_gpu"], ) st.number_input( - "Micro Batch Size Per GPU for Critic", + "Micro Batch Size Per GPU :blue-badge[(Critic)]", key="critic_ppo_micro_batch_size_per_gpu", min_value=1, max_value=st.session_state["_train_batch_size_per_gpu"], @@ -875,7 +955,7 @@ def _set_critic_ulysses_sequence_parallel_size(self): def _set_critic_lr(self): st.number_input( - "Learning Rate for Critic", + "Learning Rate :blue-badge[(Critic)]", key="critic_lr", min_value=1e-7, max_value=1e-3, @@ -884,14 +964,14 @@ def _set_critic_lr(self): def _set_critic_warmup_style(self): st.selectbox( - "LR Warmup Style for Critic", + "LR Warmup Style :blue-badge[(Critic)]", ["constant", "cosine"], key="critic_warmup_style", ) def _set_critic_lr_warmup_steps_ratio(self): st.number_input( - "LR Warmup Steps Ratio for Critic", + "LR Warmup Steps Ratio :blue-badge[(Critic)]", key="critic_lr_warmup_steps_ratio", min_value=0.0, max_value=1.0, @@ -899,10 +979,11 @@ def _set_critic_lr_warmup_steps_ratio(self): def _set_critic_grad_clip(self): st.number_input( - "Grad Clip for Critic", + "Grad Clip :blue-badge[(Critic)]", key="critic_grad_clip", min_value=0.0, max_value=1.0, + help="Clipping by Norm", ) def _set_critic_cliprange_value(self): @@ -940,38 +1021,42 @@ def beginner_mode(self): self._set_dataset_path() - self._set_configs_with_st_columns( - ["algorithm_type", "sft_warmup_iteration", "monitor_type"] - ) - if st.session_state["sft_warmup_iteration"] > 0: + self._set_configs_with_st_columns(["algorithm_type", "sft_warmup_steps", "monitor_type"]) + if st.session_state["sft_warmup_steps"] > 0: self._set_sft_warmup_dataset_path() st.header("Important Configs") self._set_configs_with_st_columns( ["node_num", "gpu_per_node", "engine_num", "tensor_parallel_size"] + if st.session_state["mode"] == "both" + else ["node_num", "gpu_per_node"] ) self._check_engine_num_and_tp_size() self._set_configs_with_st_columns( - ["total_epochs", "train_batch_size", 
"max_prompt_tokens", "max_response_tokens"] + ["total_epochs", "train_batch_size", "ppo_epochs", "repeat_times"] + if st.session_state["mode"] == "both" + else ["total_epochs", "train_batch_size", "ppo_epochs"] ) self._check_train_batch_size() + self._set_configs_with_st_columns(["max_prompt_tokens", "max_response_tokens"]) + + self._set_configs_with_st_columns( + ["sync_interval", "eval_interval", "save_interval"] + if st.session_state["mode"] == "both" + else ["eval_interval", "save_interval"] + ) + if st.session_state["algorithm_type"] != AlgorithmType.DPO.value: self._set_dataset_args() else: self._set_dpo_dataset_kwargs() - if st.session_state["sft_warmup_iteration"] > 0: + if st.session_state["sft_warmup_steps"] > 0: self._set_sft_warmup_dataset_args() - self._set_configs_with_st_columns( - ["default_workflow_type", "default_reward_fn_type", "repeat_times"] - ) - - self._set_configs_with_st_columns( - ["sync_iteration_interval", "eval_interval", "save_interval"] - ) + self._set_configs_with_st_columns(["default_workflow_type", "default_reward_fn_type"]) self._set_actor_use_kl_loss() if st.session_state["actor_use_kl_loss"]: @@ -1003,7 +1088,7 @@ def _expert_model_part(self): self._set_configs_with_st_columns(["max_prompt_tokens", "max_response_tokens"]) def _expert_buffer_part(self): - self._set_configs_with_st_columns(["total_epochs", "train_batch_size"]) + self._set_configs_with_st_columns(["total_epochs", "train_batch_size", "storage_type"]) self._check_train_batch_size() self._set_dataset_path() @@ -1013,39 +1098,45 @@ def _expert_buffer_part(self): else: self._set_dpo_dataset_kwargs() - self._set_configs_with_st_columns( - ["default_workflow_type", "default_reward_fn_type", "storage_type"] - ) + self._set_configs_with_st_columns(["default_workflow_type", "default_reward_fn_type"]) self.buffer_advanced_tab = st.expander("Advanced Config") with self.buffer_advanced_tab: - self._set_configs_with_st_columns(["max_retry_times", "max_retry_interval"]) + self._set_configs_with_st_columns(["buffer_max_retry_times", "max_retry_interval"]) self._set_sft_warmup_dataset_path() self._set_sft_warmup_dataset_args() - def _expert_connector_part(self): + def _expert_explorer_part(self): self._set_configs_with_st_columns( ["engine_type", "engine_num", "tensor_parallel_size", "repeat_times"] ) self._check_engine_num_and_tp_size() - self._set_configs_with_st_columns( - ["sync_method", "sync_iteration_interval", "sync_timeout"] - ) + self._set_configs_with_st_columns(["sync_method", "sync_interval", "sync_timeout"]) with st.expander("Advanced Config"): self._set_configs_with_st_columns( - ["runner_num", "max_pending_requests", "max_waiting_steps", "dtype"] + ["runner_num", "temperature", "top_p", "top_k", "seed", "logprobs"] ) - self._set_configs_with_st_columns(["backend", "temperature", "seed", "logprobs"]) + self._set_configs_with_st_columns(["dtype", "backend", "gpu_memory_utilization"]) + self._set_configs_with_st_columns( + [ + "max_pending_requests", + "max_waiting_steps", + "max_timeout", + "explorer_max_retry_times", + ] + ) - self._set_configs_with_st_columns(["enable_prefix_caching", "enforce_eager"]) + self._set_configs_with_st_columns( + ["enable_prefix_caching", "enforce_eager", "use_ray", "enable_chunked_prefill"] + ) def _expert_trainer_part(self): self._set_configs_with_st_columns( # TODO: may add `trainer_type` - ["algorithm_type", "sft_warmup_iteration", "eval_interval", "save_interval"] + ["algorithm_type", "sft_warmup_steps", "eval_interval", "save_interval"] ) 
self._check_sft_warmup_dataset_path() @@ -1065,7 +1156,7 @@ def _expert_verl_trainer_part(self): st.subheader("RL Training Config") self._set_training_args() - self._set_configs_with_st_columns(["training_strategy", "resume_mode"]) + self._set_configs_with_st_columns(["ppo_epochs", "training_strategy", "resume_mode"]) if st.session_state["training_strategy"] == "fsdp": self._set_configs_with_st_columns(["param_offload", "optimizer_offload"]) @@ -1134,20 +1225,18 @@ def _expert_verl_trainer_part(self): self._set_critic_checkpoint() def expert_mode(self): - model_tab, buffer_tab, connector_tab, trainer_tab = st.tabs( - ["Model", "Data", "Explorer and Synchronizer", "Trainer"] - ) - with model_tab: - self._expert_model_part() - - with buffer_tab: - self._expert_buffer_part() - - with connector_tab: - self._expert_connector_part() - - with trainer_tab: - self._expert_trainer_part() + tab2func = { + "Model": self._expert_model_part, + "Data": self._expert_buffer_part, + "Explorer and Synchronizer": self._expert_explorer_part, + "Trainer": self._expert_trainer_part, + } + if st.session_state["mode"] == "train": + del tab2func["Explorer and Synchronizer"] + tabs = st.tabs(list(tab2func.keys())) + for tab, func in zip(tabs, tab2func.values()): + with tab: + func() def _generate_verl_config(self, trainer_nnodes: int = 1, trainer_n_gpus_per_node: int = 8): balance_batch = "balance_batch" in st.session_state["training_args"] @@ -1167,7 +1256,6 @@ def _generate_verl_config(self, trainer_nnodes: int = 1, trainer_n_gpus_per_node else: fsdp_config = {} - ppo_epochs = 1 # TODO ppo_max_token_len_per_gpu = st.session_state["repeat_times"] * ( st.session_state["max_prompt_tokens"] + st.session_state["max_response_tokens"] ) @@ -1218,7 +1306,7 @@ def _generate_verl_config(self, trainer_nnodes: int = 1, trainer_n_gpus_per_node "use_kl_loss": st.session_state["actor_use_kl_loss"], "kl_loss_coef": st.session_state["actor_kl_loss_coef"], "kl_loss_type": st.session_state["actor_kl_loss_type"], - "ppo_epochs": ppo_epochs, + "ppo_epochs": st.session_state["ppo_epochs"], "shuffle": False, "ulysses_sequence_parallel_size": st.session_state[ "actor_ulysses_sequence_parallel_size" @@ -1307,7 +1395,7 @@ def _generate_verl_config(self, trainer_nnodes: int = 1, trainer_n_gpus_per_node "ulysses_sequence_parallel_size": st.session_state[ "critic_ulysses_sequence_parallel_size" ], - "ppo_epochs": ppo_epochs, + "ppo_epochs": st.session_state["ppo_epochs"], "shuffle": False, "grad_clip": st.session_state["critic_grad_clip"], "cliprange_value": st.session_state["critic_cliprange_value"], @@ -1362,7 +1450,7 @@ def _generate_verl_config(self, trainer_nnodes: int = 1, trainer_n_gpus_per_node "del_local_ckpt_after_load": st.session_state["del_local_ckpt_after_load"], "default_local_dir": st.session_state["checkpoint_path"], "val_before_train": False, - "sync_freq": st.session_state["sync_iteration_interval"], + "sync_freq": st.session_state["sync_interval"], "max_actor_ckpt_to_keep": st.session_state["max_actor_ckpt_to_keep"], "max_critic_ckpt_to_keep": st.session_state["max_critic_ckpt_to_keep"], }, @@ -1370,13 +1458,16 @@ def _generate_verl_config(self, trainer_nnodes: int = 1, trainer_n_gpus_per_node return trainer_config def generate_config(self): - trainer_nnodes = ( - st.session_state["node_num"] - - st.session_state["engine_num"] - * st.session_state["tensor_parallel_size"] - // st.session_state["gpu_per_node"] - ) - if st.session_state["node_num"] == 1: + if st.session_state["mode"] == "both": + trainer_nnodes = ( + 
st.session_state["node_num"] + - st.session_state["engine_num"] + * st.session_state["tensor_parallel_size"] + // st.session_state["gpu_per_node"] + ) + else: + trainer_nnodes = st.session_state["node_num"] + if st.session_state["node_num"] == 1 and st.session_state["mode"] == "both": trainer_n_gpus_per_node = ( st.session_state["gpu_per_node"] - st.session_state["engine_num"] * st.session_state["tensor_parallel_size"] @@ -1384,6 +1475,12 @@ def generate_config(self): else: trainer_n_gpus_per_node = st.session_state["gpu_per_node"] + critic_model_path = ( + st.session_state["critic_model_path"].strip() + if st.session_state["critic_model_path"].strip() + else st.session_state["model_path"] + ) + if st.session_state["algorithm_type"] == AlgorithmType.DPO.value: train_dataset_path = ( st.session_state["train_dataset_path"].strip() @@ -1421,6 +1518,7 @@ def generate_config(self): help=help_messages, ): config = { + "mode": st.session_state["mode"], "data": { "total_epochs": st.session_state["total_epochs"], "batch_size": st.session_state["train_batch_size"], @@ -1436,6 +1534,7 @@ def generate_config(self): }, "model": { "model_path": st.session_state["model_path"], + "critic_model_path": critic_model_path, "max_prompt_tokens": st.session_state["max_prompt_tokens"], "max_response_tokens": st.session_state["max_response_tokens"], "checkpoint_path": st.session_state["checkpoint_path"], @@ -1445,18 +1544,16 @@ def generate_config(self): "gpu_per_node": st.session_state["gpu_per_node"], }, "buffer": { - "max_retry_times": st.session_state["max_retry_times"], + "max_retry_times": st.session_state["buffer_max_retry_times"], "max_retry_interval": st.session_state["max_retry_interval"], "train_dataset": { "name": "experience_buffer", # TODO "storage_type": st.session_state["storage_type"], - "algorithm_type": st.session_state["algorithm_type"], "path": train_dataset_path, }, "sft_warmup_dataset": { "name": "sft_warmup_dataset", "storage_type": sft_storage_type, - "algorithm_type": AlgorithmType.SFT.value, "path": st.session_state["sft_warmup_dataset_path"], "kwargs": { "train_split": st.session_state["sft_warmup_train_split"], @@ -1471,28 +1568,38 @@ def generate_config(self): "engine_type": st.session_state["engine_type"], "engine_num": st.session_state["engine_num"], "runner_num": st.session_state["runner_num"], + "repeat_times": st.session_state["repeat_times"], + # "chat_template": None, # TODO: add chat template + "eval_interval": st.session_state["eval_interval"], "tensor_parallel_size": st.session_state["tensor_parallel_size"], "enable_prefix_caching": st.session_state["enable_prefix_caching"], "enforce_eager": st.session_state["enforce_eager"], "dtype": st.session_state["dtype"], "temperature": st.session_state["temperature"], + "top_p": st.session_state["top_p"], # TODO + "top_k": st.session_state["top_k"], # TODO "seed": st.session_state["seed"], "logprobs": st.session_state["logprobs"], - "repeat_times": st.session_state["repeat_times"], "backend": st.session_state["backend"], + "use_ray": st.session_state["use_ray"], # TODO + "gpu_memory_utilization": st.session_state["gpu_memory_utilization"], # TODO + "enable_chunked_prefill": st.session_state["enable_chunked_prefill"], # TODO + "use_v1": True, "max_pending_requests": st.session_state["max_pending_requests"], "max_waiting_steps": st.session_state["max_waiting_steps"], + "max_timeout": st.session_state["max_timeout"], # TODO + "max_retry_times": st.session_state["explorer_max_retry_times"], # TODO }, "synchronizer": { "sync_method": 
st.session_state["sync_method"], - "sync_iteration_interval": st.session_state["sync_iteration_interval"], + "sync_interval": st.session_state["sync_interval"], "sync_timeout": st.session_state["sync_timeout"], }, "trainer": { "trainer_type": st.session_state["trainer_type"], "algorithm_type": st.session_state["algorithm_type"], "trainer_config": trainer_config, - "sft_warmup_iteration": st.session_state["sft_warmup_iteration"], + "sft_warmup_steps": st.session_state["sft_warmup_steps"], "eval_interval": st.session_state["eval_interval"], "save_interval": st.session_state["save_interval"], }, diff --git a/trinity/manager/manager.py b/trinity/manager/manager.py index 0334fdc7ea..8d49c4cb1a 100644 --- a/trinity/manager/manager.py +++ b/trinity/manager/manager.py @@ -34,10 +34,10 @@ def _check_config_consistency(self, config: Config) -> None: f"The current config is inconsistent with the backup config in {backup_config_path}." ) - def save_explorer(self, current_task_index: int, current_iteration: int) -> None: + def save_explorer(self, current_task_index: int, current_step: int) -> None: with open(self.explorer_meta_path, "w", encoding="utf-8") as f: json.dump( - {"latest_task_index": current_task_index, "latest_iteration": current_iteration}, + {"latest_task_index": current_task_index, "latest_iteration": current_step}, f, indent=2, ) @@ -53,9 +53,9 @@ def load_explorer(self) -> dict: logger.error(f"Failed to load explore meta file: {e}") return {} - def save_trainer(self, current_iteration: int) -> None: + def save_trainer(self, current_step: int) -> None: with open(self.trainer_meta_path, "w", encoding="utf-8") as f: - json.dump({"latest_iteration": current_iteration}, f, indent=2) + json.dump({"latest_iteration": current_step}, f, indent=2) def load_trainer(self) -> dict: if os.path.exists(self.trainer_meta_path): diff --git a/trinity/trainer/trainer.py b/trinity/trainer/trainer.py index 4749a57b5c..1ea242ea10 100644 --- a/trinity/trainer/trainer.py +++ b/trinity/trainer/trainer.py @@ -34,7 +34,7 @@ def __init__(self, config: Config) -> None: self.config.buffer.sft_warmup_dataset, # type: ignore self.config.buffer, ) - if self.config.trainer.sft_warmup_iteration > 0 + if self.config.trainer.sft_warmup_steps > 0 else None ) self.engine = get_trainer_wrapper(config) @@ -46,24 +46,24 @@ def prepare(self) -> None: def train(self, algo_type: AlgorithmType = AlgorithmType.PPO): """Train the model.""" while True: - train_status, _ = self.train_iteration(algo_type) + train_status, _ = self.train_step(algo_type) if not train_status: break - def train_step(self, algo_type: AlgorithmType = AlgorithmType.PPO) -> Tuple[bool, int]: - """Train one step. Each step contains `sync_iteration_interval` iteration. + def train_one_period(self, algo_type: AlgorithmType = AlgorithmType.PPO) -> Tuple[bool, int]: + """Train for one period. Each period contains `sync_interval` steps. Returns: train_status: Whether to continue training. 
- train_iter_num: The number of training iterations""" - for _ in range(self.config.synchronizer.sync_iteration_interval): - train_status, train_iter_num = self.train_iteration(algo_type) + train_step_num: The number of training steps""" + for _ in range(self.config.synchronizer.sync_interval): + train_status, train_step_num = self.train_step(algo_type) if not train_status: - return False, train_iter_num - self.logger.info(f"Trainer iteration {train_iter_num} finished.") - return True, train_iter_num + return False, train_step_num + self.logger.info(f"Trainer steps {train_step_num} finished.") + return True, train_step_num - def train_iteration(self, algo_type: AlgorithmType = AlgorithmType.PPO) -> Tuple[bool, int]: - """Train one iteration. + def train_step(self, algo_type: AlgorithmType = AlgorithmType.PPO) -> Tuple[bool, int]: + """Train one step. Args: algo_type (AlgorithmType): The type of data to be used for training. @@ -75,7 +75,7 @@ def train_iteration(self, algo_type: AlgorithmType = AlgorithmType.PPO) -> Tuple self.engine.set_mode(algo_type) if algo_type.is_sft(): exps = self.sft_warmup_buffer.read() - return self.engine.train_sft_iteration( + return self.engine.train_sft_step( Experiences.gather_experiences( exps, pad_token_id=self.config.buffer.pad_token_id, # type: ignore @@ -90,8 +90,8 @@ def train_iteration(self, algo_type: AlgorithmType = AlgorithmType.PPO) -> Tuple exps = self.train_buffer.read(strategy=strategy) except StopIteration: self.logger.warning("No more data to train. Stop training.") - return False, 0 # TODO: get the actual iteration number - return self.engine.train_rft_iteration( + return False, 0 # TODO: get the actual step number + return self.engine.train_rft_step( Experiences.gather_experiences( exps, pad_token_id=self.config.buffer.pad_token_id, # type: ignore @@ -99,7 +99,7 @@ def train_iteration(self, algo_type: AlgorithmType = AlgorithmType.PPO) -> Tuple ) elif algo_type.is_dpo(): exps = self.train_buffer.read() - return self.engine.train_dpo_iteration( + return self.engine.train_dpo_step( Experiences.gather_dpo_experiences( exps, pad_token_id=self.config.buffer.pad_token_id, # type: ignore @@ -126,15 +126,15 @@ def prepare(self) -> None: """Do some preparation before training started.""" @abstractmethod - def train_rft_iteration(self, experiences) -> Tuple[bool, int]: + def train_rft_step(self, experiences) -> Tuple[bool, int]: """Train on the RFT data.""" @abstractmethod - def train_sft_iteration(self, experiences) -> Tuple[bool, int]: + def train_sft_step(self, experiences) -> Tuple[bool, int]: """Train on the SFT data.""" @abstractmethod - def train_dpo_iteration(self, experiences) -> Tuple[bool, int]: + def train_dpo_step(self, experiences) -> Tuple[bool, int]: """Train on the DPO data.""" @abstractmethod diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py index ffb1d2be12..1f4f9ddfb8 100644 --- a/trinity/trainer/verl_trainer.py +++ b/trinity/trainer/verl_trainer.py @@ -118,7 +118,6 @@ def __init__( resource_pool_spec=resource_pool_spec, mapping=mapping ) - self.sft_iter_num = 0 super().__init__( config, tokenizer, @@ -147,9 +146,11 @@ def prepare(self): self.actor_rollout_wg.setup_weight_sync_group() self.global_steps = 0 + self.sft_warmup_step_num = 0 # load checkpoint before doing anything self._load_checkpoint() + self.sft_warmup_step_num = min(self.global_steps, self.config.trainer.sft_warmup_steps) # perform validation before training # currently, we only support validation using the reward_function. 
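The hunks above rename the per-iteration API to per-step (`train_step`, `train_one_period`, `train_sft_step`, ...) and move SFT warm-up bookkeeping onto a dedicated `sft_warmup_step_num` counter that is re-derived after checkpoint loading. Below is a minimal standalone sketch of that counter logic under stated assumptions: `SyncConfig`, `TrainerConfig`, and `ToyTrainer` are invented stand-ins, not Trinity-RFT classes; only the step/period/warm-up arithmetic mirrors the patch.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass
class SyncConfig:
    sync_interval: int = 10  # steps between explorer/trainer weight syncs


@dataclass
class TrainerConfig:
    sft_warmup_steps: int = 3  # SFT steps to run before RFT takes over


class ToyTrainer:
    """Toy stand-in that only reproduces the step counters implied by the patch."""

    def __init__(self, sync: SyncConfig, trainer: TrainerConfig, total_steps: int = 25):
        self.sync = sync
        self.trainer = trainer
        self.total_steps = total_steps  # toy data budget
        self.global_steps = 0
        self.sft_warmup_step_num = 0

    def train_sft_step(self) -> Tuple[bool, int]:
        # Refuse to run once the configured warm-up budget is used up.
        if self.sft_warmup_step_num >= self.trainer.sft_warmup_steps:
            return False, self.global_steps - 1
        self.sft_warmup_step_num += 1
        self.global_steps += 1
        # As in the patch, the step that completes the warm-up reports False.
        return self.sft_warmup_step_num < self.trainer.sft_warmup_steps, self.global_steps - 1

    def train_step(self) -> Tuple[bool, int]:
        # One RFT step; stop when the toy data budget is exhausted.
        if self.global_steps >= self.total_steps:
            return False, self.global_steps - 1
        self.global_steps += 1
        return True, self.global_steps - 1

    def train_one_period(self) -> Tuple[bool, int]:
        # One period = `sync_interval` steps between two weight syncs.
        last_step = self.global_steps - 1
        for _ in range(self.sync.sync_interval):
            ok, last_step = self.train_step()
            if not ok:
                return False, last_step
        return True, last_step


if __name__ == "__main__":
    t = ToyTrainer(SyncConfig(sync_interval=10), TrainerConfig(sft_warmup_steps=3))
    while t.train_sft_step()[0]:
        pass
    print(t.sft_warmup_step_num)   # 3 warm-up steps were run
    ok, last_step = t.train_one_period()
    print(ok, last_step)           # True 12 (steps 3..12 ran in this period)
```

The sketch also shows why the warm-up steps now advance `global_steps`: RFT training resumes from the step index where the warm-up left off instead of restarting from zero.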
@@ -183,7 +184,7 @@ def _create_dataloader(self): # else: self.total_training_steps = float("inf") - def train_dpo_iteration(self, experiences: Experiences) -> Tuple[bool, int]: + def train_dpo_step(self, experiences: Experiences) -> Tuple[bool, int]: metrics = {} timing_raw = {} @@ -246,7 +247,9 @@ def train_dpo_iteration(self, experiences: Experiences) -> Tuple[bool, int]: self.global_steps += 1 return True, self.global_steps - 1 - def train_sft_iteration(self, experiences: Experiences) -> Tuple[bool, int]: + def train_sft_step(self, experiences: Experiences) -> Tuple[bool, int]: + if self.sft_warmup_step_num >= self.config.trainer.sft_warmup_steps: + return False, self.global_steps - 1 metrics = {} timing_raw = {} @@ -296,24 +299,20 @@ def train_sft_iteration(self, experiences: Experiences) -> Tuple[bool, int]: metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw)) # TODO: log as sft metrics - self.sft_iter_num += 1 self.logger.log(data=metrics, step=self.global_steps) - # print(f'{self.sft_iter_num=}, {self.config.synchronizer.sync_iteration_interval * self.config.trainer.sft_warmup_iteration=}') - if ( - self.sft_iter_num - == self.config.synchronizer.sync_iteration_interval - * self.config.trainer.sft_warmup_iteration - ): + self.sft_warmup_step_num += 1 + self.global_steps += 1 + if self.sft_warmup_step_num == self.config.trainer.sft_warmup_steps: self.logger.log( - data={"sft_warmup_iteration": self.sft_iter_num}, + data={"sft_warmup_steps": self.sft_warmup_step_num}, step=self.global_steps, ) with _timer("save_checkpoint", timing_raw): self._save_checkpoint() - self.global_steps += 1 + return False, self.global_steps - 1 return True, self.global_steps - 1 - def train_rft_iteration(self, experiences: Experiences) -> Tuple[bool, int]: + def train_rft_step(self, experiences: Experiences) -> Tuple[bool, int]: metrics = {} timing_raw = {}
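Several of the config-manager hunks above hinge on the same mode-aware GPU split: in `both` mode the rollout engines reserve GPUs, while `train` mode (used by DPO, which has no explorer) gives everything to the trainer. The helper below is an invented illustration of that arithmetic, not code from the repository; in the actual manager the inputs come from Streamlit session state.

```python
def trainer_gpu_num(mode: str, total_gpu_num: int, engine_num: int, tensor_parallel_size: int) -> int:
    """Illustrative only: `both` mode subtracts the GPUs reserved by the
    rollout engines; `train` mode hands every GPU to the trainer."""
    if mode == "both":
        return total_gpu_num - engine_num * tensor_parallel_size
    return total_gpu_num


# 2 nodes x 8 GPUs with 2 vLLM engines at TP=1 leaves 14 GPUs for training.
print(trainer_gpu_num("both", total_gpu_num=16, engine_num=2, tensor_parallel_size=1))   # 14
# A DPO run uses `train` mode, so all 16 GPUs go to the trainer.
print(trainer_gpu_num("train", total_gpu_num=16, engine_num=2, tensor_parallel_size=1))  # 16
```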