diff --git a/docs/sphinx_doc/source/tutorial/example_data_functionalities.md b/docs/sphinx_doc/source/tutorial/example_data_functionalities.md index fb9a82873b..d7528fa0df 100644 --- a/docs/sphinx_doc/source/tutorial/example_data_functionalities.md +++ b/docs/sphinx_doc/source/tutorial/example_data_functionalities.md @@ -42,7 +42,7 @@ data: # database related. The result dataset will be stored in the database. db_url: 'postgresql://{user_name}@localhost:5432/{db_name}' # downstream loading related - total_epoch: 1 + total_epochs: 1 batch_size: 96 default_workflow_type: 'math_workflow' ``` @@ -53,7 +53,7 @@ Here you can set the basic information for the GSM-8K dataset, database informat + `dataset_config`: extra config arguments for loading the raw dataset. Mainly for the `load_dataset` method in HuggingFace `datasets` library. + `format_config`: some dataset format config items, which are used to map original data field names to unified ones. + `db_url`: the URL of the postgresql database to store the result dataset. -+ `total_epoch`: the total number of epochs to train on this dataset. ++ `total_epochs`: the total number of epochs to train on this dataset. + `batch_size`: the training batch size. + `default_workflow_type`: the default exploring workflow type. Please refer to [programming guide](trinity_programming_guide.md) for more details. @@ -74,7 +74,7 @@ data: # database related. The result dataset will be stored in the database. db_url: 'postgresql://{user_name}@localhost:5432/{db_name}' # downstream loading related - total_epoch: 1 + total_epochs: 1 batch_size: 96 default_workflow_type: 'math_workflow' @@ -120,7 +120,7 @@ data: # database related. The result dataset will be stored in the database. db_url: 'postgresql://{user_name}@localhost:5432/{db_name}' # downstream loading related - total_epoch: 1 + total_epochs: 1 batch_size: 96 default_workflow_type: 'math_workflow' @@ -199,7 +199,7 @@ data: # database related. The result dataset will be stored in the database. db_url: 'postgresql://{user_name}@localhost:5432/{db_name}' # downstream loading related - total_epoch: 20 + total_epochs: 20 batch_size: 32 default_workflow_type: 'math_workflow' ``` diff --git a/docs/sphinx_doc/source/tutorial/trinity_configs.md b/docs/sphinx_doc/source/tutorial/trinity_configs.md index 0ef8e93db5..0be1153630 100644 --- a/docs/sphinx_doc/source/tutorial/trinity_configs.md +++ b/docs/sphinx_doc/source/tutorial/trinity_configs.md @@ -3,6 +3,18 @@ The following is the main config file for Trinity-RFT. Take `countdown.yaml` as an example. +## Monitor + +```yaml +monitor: + project: "Trinity-RFT-countdown" + name: "qwen2.5-1.5B-countdown" +``` + +- `monitor.project`: The project name. It must be set manually. +- `monitor.name`: The name of the experiment. It must be set manually. + + ## Monitor ```yaml @@ -33,7 +45,7 @@ data: max_retry_times: 3 max_retry_interval: 1 - total_epoch: 20 + total_epochs: 20 batch_size: 96 default_workflow_type: 'math_workflow' default_reward_fn_type: 'countdown_reward' @@ -47,7 +59,7 @@ data: - `data.db_url`: The URL of the database. - `data.max_retry_times`: The maximum number of retries when loading the dataset from database. - `data.max_retry_interval`: The maximum interval between retries when loading the dataset from database. -- `data.total_epoch`: The total number of epochs to explore the dataset. Default is `1`. It should be set manually. +- `data.total_epochs`: The total number of epochs to explore the dataset. Default is `1`. It should be set manually. 
 - `data.batch_size`: The number of `Task` in one training batch. The real batch size used in training is `data.batch_size` * `actor_rollout_ref.rollout.n` Default is `1`. It should be set manually.
 - `data.default_workflow_type`: The default workflow type used for training.
 - `data.default_reward_fn_type`: The default reward function type used for training.
@@ -345,10 +357,14 @@ algorithm:
   gamma: 1.0
   lam: 1.0
   adv_estimator: gae
+  norm_adv_by_std_in_grpo: True
+  use_kl_in_reward: False
   kl_penalty: kl  # how to estimate kl divergence
   kl_ctrl:
     type: fixed
     kl_coef: 0.001
+    horizon: 10000
+    target_kl: 0.1

 trainer:
   balance_batch: True
@@ -363,7 +379,7 @@ trainer:
   save_freq: 100
   # auto: find the last ckpt to resume. If can't find, start from scratch
   resume_mode: auto # or auto or resume_path if
-  resume_from_path: False
+  resume_from_path: ""
   test_freq: 100
   critic_warmup: 0
   default_hdfs_dir: null
@@ -383,8 +399,9 @@ trainer:
 - `actor_rollout_ref.actor.grad_clip`: Gradient clip for actor model training.
 - `actor_rollout_ref.actor.clip_ratio`: Used for compute policy loss.
 - `actor_rollout_ref.actor.entropy_coeff`: Used for compute policy loss.
-- `actor_rollout_ref.actor.use_kl_loss`: True for GRPO.
-- `actor_rollout_ref.actor.kl_loss_coef`: Used for GRPO, optional value is `kl`, `abs`, `mse` or `low_var_kl`.
+- `actor_rollout_ref.actor.use_kl_loss`: Whether to enable the KL loss.
+- `actor_rollout_ref.actor.kl_loss_coef`: The coefficient of the KL loss.
+- `actor_rollout_ref.actor.kl_loss_type`: How to compute the KL loss; possible values are `kl`, `abs`, `mse` or `low_var_kl`.
 - `actor_rollout_ref.actor.ulysses_sequence_parallel_size`: Ulysses sequence parallel size.
 - `actor_rollout_ref.actor.alg_type`: Used for OPMD, optional value is `ppo`, `opmd` or `pairwise_opmd`.
 - `actor_rollout_ref.actor.tau`: strength of regularization w.r.t. old / ref policy.
diff --git a/examples/dpo_humanlike/dpo.yaml b/examples/dpo_humanlike/dpo.yaml
index 991c662138..29c49bb1a0 100644
--- a/examples/dpo_humanlike/dpo.yaml
+++ b/examples/dpo_humanlike/dpo.yaml
@@ -1,6 +1,6 @@
 mode: train
 data:
-  total_epoch: 20
+  total_epochs: 20
   batch_size: 32 # NOTE
   train_split: "train"
   dataset_path: ''
@@ -22,7 +22,6 @@ buffer:
   train_dataset:
     name: dpo_buffer
     storage_type: file
-    algorithm_type: dpo
     path: '/PATH/TO/DATASET/'
     kwargs:
       prompt_type: plaintext # plaintext/messages
diff --git a/examples/dpo_humanlike/train_dpo.yaml b/examples/dpo_humanlike/train_dpo.yaml
index ae7689106b..4da9a7ddb5 100644
--- a/examples/dpo_humanlike/train_dpo.yaml
+++ b/examples/dpo_humanlike/train_dpo.yaml
@@ -173,7 +173,6 @@ trainer:
   save_freq: 30
   # auto: find the last ckpt to resume.
If can't find, start from scratch resume_mode: auto # or auto or resume_path if - resume_from_path: False test_freq: 5 critic_warmup: 0 default_hdfs_dir: null diff --git a/examples/grpo_alfworld/alfworld.yaml b/examples/grpo_alfworld/alfworld.yaml index 100420acd8..0258bfdbf0 100644 --- a/examples/grpo_alfworld/alfworld.yaml +++ b/examples/grpo_alfworld/alfworld.yaml @@ -1,5 +1,5 @@ data: - total_epoch: 20 + total_epochs: 20 batch_size: 4 dataset_path: 'scripts/data_prepare/alfworld_data' default_workflow_type: 'alfworld_workflow' @@ -21,7 +21,6 @@ buffer: train_dataset: name: alfworld_buffer storage_type: queue - algorithm_type: ppo path: 'sqlite:///alfworld.db' explorer: engine_type: vllm_async diff --git a/examples/grpo_alfworld/train_alfworld.yaml b/examples/grpo_alfworld/train_alfworld.yaml index 0a1c109754..88f151bdcb 100644 --- a/examples/grpo_alfworld/train_alfworld.yaml +++ b/examples/grpo_alfworld/train_alfworld.yaml @@ -172,7 +172,6 @@ trainer: save_freq: 1 # auto: find the last ckpt to resume. If can't find, start from scratch resume_mode: auto # or auto or resume_path if - resume_from_path: False test_freq: 100 critic_warmup: 0 default_hdfs_dir: null diff --git a/examples/grpo_gsm8k/gsm8k.yaml b/examples/grpo_gsm8k/gsm8k.yaml index 677027ea19..63850d5d24 100644 --- a/examples/grpo_gsm8k/gsm8k.yaml +++ b/examples/grpo_gsm8k/gsm8k.yaml @@ -18,7 +18,7 @@ data: # db related db_url: '' # downstream loading related - total_epoch: 1 + total_epochs: 1 batch_size: 96 default_workflow_type: 'math_workflow' model: @@ -35,12 +35,10 @@ buffer: train_dataset: name: gsm8k_buffer storage_type: queue - algorithm_type: ppo path: 'sqlite:///gsm8k.db' # sft_warmup_dataset: # Uncomment these to enable sft warmup # name: warmup_data # storage_type: file - # algorithm_type: sft # path: '/PATH/TO/WARMUP_DATA/' # kwargs: # prompt_type: plaintext diff --git a/examples/grpo_gsm8k/train_gsm8k.yaml b/examples/grpo_gsm8k/train_gsm8k.yaml index eeb64dc746..2e8365c6cb 100644 --- a/examples/grpo_gsm8k/train_gsm8k.yaml +++ b/examples/grpo_gsm8k/train_gsm8k.yaml @@ -177,7 +177,6 @@ trainer: save_freq: 100 # auto: find the last ckpt to resume. If can't find, start from scratch resume_mode: auto # or auto or resume_path if - resume_from_path: False test_freq: 5 critic_warmup: 0 default_hdfs_dir: null diff --git a/examples/grpo_math/math.yaml b/examples/grpo_math/math.yaml index ae7e40d9d0..07ea448548 100644 --- a/examples/grpo_math/math.yaml +++ b/examples/grpo_math/math.yaml @@ -10,7 +10,7 @@ data: # db related db_url: '' # downstream loading related - total_epoch: 20 + total_epochs: 20 batch_size: 288 default_workflow_type: 'math_workflow' model: @@ -27,8 +27,7 @@ buffer: train_dataset: name: math_buffer storage_type: queue - algorithm_type: ppo - path: 'sqlite:////math.db' + path: 'sqlite:///math.db' explorer: engine_type: vllm_async engine_num: 2 diff --git a/examples/grpo_math/train_math.yaml b/examples/grpo_math/train_math.yaml index c49fda58ce..937c19657e 100644 --- a/examples/grpo_math/train_math.yaml +++ b/examples/grpo_math/train_math.yaml @@ -169,7 +169,6 @@ trainer: save_freq: 100 # auto: find the last ckpt to resume. 
If can't find, start from scratch resume_mode: auto # or auto or resume_path if - resume_from_path: False test_freq: 5 critic_warmup: 0 default_hdfs_dir: null diff --git a/examples/grpo_sciworld/sciworld.yaml b/examples/grpo_sciworld/sciworld.yaml index d036fc54fc..25b5dfa073 100644 --- a/examples/grpo_sciworld/sciworld.yaml +++ b/examples/grpo_sciworld/sciworld.yaml @@ -1,5 +1,5 @@ data: - total_epoch: 20 + total_epochs: 20 batch_size: 4 dataset_path: 'scripts/data_prepare/sciworld_data' default_workflow_type: 'sciworld_workflow' @@ -21,7 +21,6 @@ buffer: train_dataset: name: sciworld_buffer storage_type: queue - algorithm_type: ppo path: 'sqlite:///sciworld.db' explorer: engine_type: vllm_async diff --git a/examples/grpo_sciworld/train_sciworld.yaml b/examples/grpo_sciworld/train_sciworld.yaml index 880fc61fcc..330b659afb 100644 --- a/examples/grpo_sciworld/train_sciworld.yaml +++ b/examples/grpo_sciworld/train_sciworld.yaml @@ -167,7 +167,6 @@ trainer: save_freq: 1 # auto: find the last ckpt to resume. If can't find, start from scratch resume_mode: auto # or auto or resume_path if - resume_from_path: False test_freq: 100 critic_warmup: 0 default_hdfs_dir: null diff --git a/examples/grpo_webshop/train_webshop.yaml b/examples/grpo_webshop/train_webshop.yaml index aac06b7043..0ae8675f50 100644 --- a/examples/grpo_webshop/train_webshop.yaml +++ b/examples/grpo_webshop/train_webshop.yaml @@ -172,7 +172,6 @@ trainer: save_freq: 1 # auto: find the last ckpt to resume. If can't find, start from scratch resume_mode: auto # or auto or resume_path if - resume_from_path: False test_freq: 100 critic_warmup: 0 default_hdfs_dir: null diff --git a/examples/grpo_webshop/webshop.yaml b/examples/grpo_webshop/webshop.yaml index ff5c2c57bf..eb9916018d 100644 --- a/examples/grpo_webshop/webshop.yaml +++ b/examples/grpo_webshop/webshop.yaml @@ -1,5 +1,5 @@ data: - total_epoch: 20 + total_epochs: 20 batch_size: 4 dataset_path: 'scripts/data_prepare/webshop_data' default_workflow_type: 'webshop_workflow' @@ -21,7 +21,6 @@ buffer: train_dataset: name: webshop_buffer storage_type: queue - algorithm_type: ppo path: 'sqlite:///webshop.db' explorer: engine_type: vllm_async diff --git a/examples/opmd_gsm8k/opmd_gsm8k.yaml b/examples/opmd_gsm8k/opmd_gsm8k.yaml index 5a458b24e8..6cde601158 100644 --- a/examples/opmd_gsm8k/opmd_gsm8k.yaml +++ b/examples/opmd_gsm8k/opmd_gsm8k.yaml @@ -1,5 +1,5 @@ data: - total_epoch: 1 + total_epochs: 1 batch_size: 96 dataset_path: '{path to datasets}/gsm8k' default_workflow_type: 'math_workflow' @@ -20,7 +20,6 @@ buffer: train_dataset: name: gsm8k_buffer storage_type: queue - algorithm_type: opmd path: 'sqlite:///gsm8k_opmd.db' explorer: engine_type: vllm_async diff --git a/examples/opmd_gsm8k/train_opmd_gsm8k.yaml b/examples/opmd_gsm8k/train_opmd_gsm8k.yaml index a82f4d0739..97384e57c3 100644 --- a/examples/opmd_gsm8k/train_opmd_gsm8k.yaml +++ b/examples/opmd_gsm8k/train_opmd_gsm8k.yaml @@ -204,7 +204,6 @@ trainer: save_freq: 100 # auto: find the last ckpt to resume. 
If can't find, start from scratch resume_mode: auto # or auto or resume_path if - resume_from_path: False test_freq: 100 critic_warmup: 0 default_hdfs_dir: null diff --git a/examples/ppo_countdown/countdown.yaml b/examples/ppo_countdown/countdown.yaml index a531eebda0..9282a7d1a0 100644 --- a/examples/ppo_countdown/countdown.yaml +++ b/examples/ppo_countdown/countdown.yaml @@ -1,5 +1,5 @@ data: - total_epoch: 20 + total_epochs: 20 batch_size: 96 dataset_path: 'countdown_dataset/oneshot-split' default_workflow_type: 'math_workflow' @@ -23,8 +23,7 @@ buffer: train_dataset: name: countdown_buffer storage_type: queue - algorithm_type: ppo - path: 'sqlite:////countdown.db' + path: 'sqlite:///countdown.db' explorer: engine_type: vllm_async engine_num: 2 diff --git a/examples/ppo_countdown/train_countdown.yaml b/examples/ppo_countdown/train_countdown.yaml index 53a1401c4b..70872fd7f1 100644 --- a/examples/ppo_countdown/train_countdown.yaml +++ b/examples/ppo_countdown/train_countdown.yaml @@ -179,7 +179,6 @@ trainer: save_freq: 100 # auto: find the last ckpt to resume. If can't find, start from scratch resume_mode: auto # or auto or resume_path if - resume_from_path: False test_freq: 100 critic_warmup: 0 default_hdfs_dir: null diff --git a/pyproject.toml b/pyproject.toml index 717a88a00c..8a6f226c77 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,7 @@ dependencies = [ "math_verify", "ninja", "fire", + "streamlit", "flask", "requests", "tensorboard", diff --git a/tests/common/tmp/template_config.yaml b/tests/common/tmp/template_config.yaml index 2cae62acdf..e163623680 100644 --- a/tests/common/tmp/template_config.yaml +++ b/tests/common/tmp/template_config.yaml @@ -1,7 +1,7 @@ mode: both data: dataset_path: '' - total_epoch: 1 + total_epochs: 1 batch_size: 32 train_split: 'train' eval_split: '' diff --git a/trinity/common/config.py b/trinity/common/config.py index 8fdd9f8e27..587c7c5fe8 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -69,7 +69,7 @@ class DataConfig: max_retry_interval: int = 1 # downstream loading related - total_epoch: int = 1 + total_epochs: int = 1 batch_size: int = 1 default_workflow_type: str = "" default_reward_fn_type: str = "" @@ -101,7 +101,7 @@ class DatasetConfig: name: str storage_type: StorageType - algorithm_type: AlgorithmType + algorithm_type: AlgorithmType = AlgorithmType.PPO path: Optional[str] = None kwargs: Dict[str, Any] = field(default_factory=dict) @@ -143,7 +143,7 @@ class ExplorerConfig: # For async engine (vllm_async), it can be larger than `engine_num`, e.g. 
16 * `engine_num`
     runner_num: int = 1
-    # repeat each task for `repeat_times` times (for GPRO-like algrorithms)
+    # repeat each task for `repeat_times` times (for GRPO-like algorithms)
     repeat_times: int = 1

     # for rollout tokneize
@@ -177,7 +177,7 @@ class TrainerConfig:
     trainer_config_path: str = ""
     eval_interval: int = 100
     enable_preview: bool = True  # enable rollout preview in wandb
-    trainer_config: Any = None
+    trainer_config: Any = field(default_factory=dict)

     # train algorithm
     algorithm_type: AlgorithmType = AlgorithmType.PPO
@@ -266,21 +266,29 @@ def _check_buffer(self) -> None:
         else:
             if self.buffer.train_dataset is None:
                 raise ValueError("buffer.train_dataset is required when mode is not 'both'")
-            if self.buffer.train_dataset.algorithm_type != self.trainer.algorithm_type:
-                raise ValueError(
-                    f"buffer.train_dataset.algorithm_type ({self.buffer.train_dataset.algorithm_type}) "
-                    f"is not consistent with trainer.algorithm_type ({self.trainer.algorithm_type})"
-                )
+            self.buffer.train_dataset.algorithm_type = self.trainer.algorithm_type
+        if self.buffer.sft_warmup_dataset is not None:
+            self.buffer.sft_warmup_dataset.algorithm_type = AlgorithmType.SFT
         self.buffer.read_batch_size = self.data.batch_size * self.explorer.repeat_times

     def check_and_update(self) -> None:
         """Check and update the config."""
         if self.trainer.trainer_type == "verl":
-            from trinity.common.verl_config import load_config
-
-            if not os.path.isfile(self.trainer.trainer_config_path):
-                raise ValueError(f"Invalid trainer config path: {self.trainer.trainer_config_path}")
-            self.trainer.trainer_config = load_config(self.trainer.trainer_config_path)
+            if self.trainer.trainer_config:
+                from trinity.common.verl_config import veRLConfig
+
+                trainer_config_schema = OmegaConf.structured(veRLConfig)
+                trainer_config = OmegaConf.merge(trainer_config_schema, self.trainer.trainer_config)
+                self.trainer.trainer_config = OmegaConf.to_object(trainer_config)
+            else:
+                if os.path.isfile(self.trainer.trainer_config_path):
+                    from trinity.common.verl_config import load_config
+
+                    self.trainer.trainer_config = load_config(self.trainer.trainer_config_path)
+                else:
+                    raise ValueError(
+                        f"Invalid trainer config path: {self.trainer.trainer_config_path}"
+                    )
         else:
             raise ValueError(f"Invalid trainer type: {self.trainer_type}")

diff --git a/trinity/common/task.py b/trinity/common/task.py
index bef638b6c9..5f5309565e 100644
--- a/trinity/common/task.py
+++ b/trinity/common/task.py
@@ -121,7 +121,7 @@ class TaskSet:
     task_type: Optional[TaskType] = None
     default_index: int = 0
     default_epoch: int = 0
-    total_epoch: int = 1
+    total_epochs: int = 1
     _tasks: Iterator[Task] = None
     _index: int = 0
     _epoch: int = 0
@@ -160,7 +160,7 @@ def load(
             task_type=task_type,
             default_index=latest_task_index % dataset_len,
             default_epoch=latest_task_index // dataset_len,
-            total_epoch=config.total_epoch if task_type == TaskType.EXPLORE else 1,
+            total_epochs=config.total_epochs if task_type == TaskType.EXPLORE else 1,
         )

     def __iter__(self) -> Iterator[Task]:
@@ -189,7 +189,7 @@ def epoch(self) -> int:

     def __next__(self) -> Task:
         """Iterate through the tasks in the taskset."""
-        if self._epoch >= self.total_epoch:
+        if self._epoch >= self.total_epochs:
             raise StopIteration

         try:
@@ -204,7 +204,7 @@ def __next__(self) -> Task:
             # Reset the task generator and increment the epoch
             self._epoch += 1
             self._index += 1
-            if self._epoch >= self.total_epoch:
+            if self._epoch >= self.total_epochs:
                 raise StopIteration
             self._tasks = task_generator(
                 self.dataset,
diff --git
a/trinity/common/verl_config.py b/trinity/common/verl_config.py index a9d7dd6cb4..966fde0391 100644 --- a/trinity/common/verl_config.py +++ b/trinity/common/verl_config.py @@ -83,13 +83,13 @@ class Actor: ppo_epochs: int = 1 shuffle: bool = False ulysses_sequence_parallel_size: int = 1 + checkpoint: Checkpoint = field(default_factory=Checkpoint) optim: Optim = field(default_factory=Optim) fsdp_config: FSDPConfig = field(default_factory=FSDPConfig) alg_type: str = "ppo" # ppo / opmd / pairwise_opmd tau: float = 0.001 # strength of regularization w.r.t. old / ref policy opmd_baseline: str = "mean" # mean / logavgexp, applicable to opmd use_uid: bool = False # True / False, applicable to pairwise_opmd - checkpoint: Checkpoint = field(default_factory=Checkpoint) @dataclass @@ -205,6 +205,8 @@ class CustomRewardFunction: class KL_Ctrl: type: str = "fixed" kl_coef: float = 0.001 + horizon: float = 10000 + target_kl: float = 0.1 @dataclass @@ -212,6 +214,8 @@ class Algorithm: gamma: float = 1.0 lam: float = 1.0 adv_estimator: str = "gae" + norm_adv_by_std_in_grpo: bool = True + use_kl_in_reward: bool = False kl_penalty: str = "kl" kl_ctrl: KL_Ctrl = field(default_factory=KL_Ctrl) @@ -229,7 +233,7 @@ class Trainer: n_gpus_per_node: int = 0 save_freq: int = 0 resume_mode: str = "auto" - resume_from_path: bool = False + resume_from_path: str = "" test_freq: int = 0 critic_warmup: int = 0 default_hdfs_dir: Optional[str] = None @@ -300,8 +304,10 @@ def synchronize_config(self, config: Config) -> None: self.actor_rollout_ref.rollout.temperature = config.explorer.temperature self.actor_rollout_ref.rollout.n = config.explorer.repeat_times batch_size_per_gpu = self.buffer.read_batch_size // world_size - self.actor_rollout_ref.actor.alg_type = config.trainer.algorithm_type.value - print(f"using algorithm type: {self.actor_rollout_ref.actor.alg_type}") + self.actor_rollout_ref.actor.alg_type = ( + config.trainer.algorithm_type.value + ) # TODO: refactor `alg_type` + # print(f"using algorithm type: {self.actor_rollout_ref.actor.alg_type}") if self.actor_rollout_ref.actor.alg_type == "dpo": # for DPO print("Warning: DPO micro batch size is doubled for computing loss.") diff --git a/trinity/manager/config_manager.py b/trinity/manager/config_manager.py index c4436705ab..13ade4b011 100644 --- a/trinity/manager/config_manager.py +++ b/trinity/manager/config_manager.py @@ -1,838 +1,1353 @@ +import copy import os +from typing import List import streamlit as st import yaml -from verl.trainer.ppo.ray_trainer import AdvantageEstimator +from trinity.common.constants import AlgorithmType, MonitorType, StorageType from trinity.common.rewards import REWARD_FUNCTIONS from trinity.common.workflows.workflow import WORKFLOWS +from trinity.trainer.verl.ray_trainer import AdvantageEstimator class ConfigManager: def __init__(self): + self._init_default_config() + self.unfinished_fields = set() st.set_page_config(page_title="Trainer Config Generator", page_icon=":robot:") st.title("Trainer Config Generator") - self.reset_config() - self.unfinished_flag = False + if "_init_config_manager" not in st.session_state: + self.reset_session_state() + self.maintain_session_state() + mode = st.pills( + "Select Mode", + options=["Beginner Mode", "Expert Mode"], + default="Beginner Mode", + label_visibility="collapsed", + ) + if mode == "Beginner Mode": + self.beginner_mode() + else: + self.expert_mode() + self.generate_config() - def reset_config(self): - pass + def _init_default_config(self): + self.default_config = { + 
"_init_config_manager": True, + "project": "Trinity-RFT", + "exp_name": "qwen2.5-1.5B", + "monitor_type": MonitorType.WANDB.value, + # Model Configs + "model_path": "", + "critic_model_path": "", + "checkpoint_path": "", + "node_num": 1, + "gpu_per_node": 8, + "total_gpu_num": 8, + "trainer_gpu_num": 6, + "max_prompt_tokens": 1024, + "max_response_tokens": 1024, + # Data and Buffer Configs + "total_epochs": 20, + "task_num_per_batch": 6, + "dataset_path": "", + "subset_name": None, + "train_split": "train", + "eval_split": "", + "prompt_key": "question", + "response_key": "answer", + "default_workflow_type": "math_workflow", + "default_reward_fn_type": "math_reward", + "storage_type": StorageType.QUEUE.value, + "db_url": "", + "max_retry_times": 3, + "max_retry_interval": 1, + "sft_warmup_dataset_path": "", + "sft_warmup_train_split": "train", + "sft_warmup_eval_split": "", + "sft_warmup_prompt_key": "question", + "sft_warmup_response_key": "answer", + # Explorer and Sync Configs + "engine_type": "vllm_async", + "engine_num": 2, + "tensor_parallel_size": 1, + "repeat_times": 1, + "sync_method": "online", + "sync_iteration_interval": 10, + "runner_num": 32, + "max_pending_requests": 32, + "max_waiting_steps": 4, + "dtype": "bfloat16", + "backend": "nccl", + "temperature": 1.0, + "top_p": 1.0, + "top_k": -1, + "seed": 42, + "logprobs": 0, + "enable_prefix_caching": False, + "enforce_eager": True, + # Trainer Configs + "trainer_type": "verl", + "algorithm_type": AlgorithmType.PPO.value, + "sft_warmup_iteration": 0, + "eval_interval": 1000, + # veRL Trainer Configs + "training_args": [ + "balance_batch", + "gradient_checkpointing", + "remove_padding", + "dynamic_bsz", + ], + "save_freq": 100, + "training_strategy": "fsdp", + "param_offload": False, + "optimizer_offload": False, + "resume_mode": "auto", + "resume_from_path": "", + "critic_warmup": 0, + "total_training_steps": None, + "default_hdfs_dir": None, + "remove_previous_ckpt_in_save": False, + "del_local_ckpt_after_load": False, + "max_actor_ckpt_to_keep": None, + "max_critic_ckpt_to_keep": None, + "gamma": 1.0, + "lam": 1.0, + "adv_estimator": "gae", + "norm_adv_by_std_in_grpo": True, + "use_kl_in_reward": False, + "kl_penalty": "low_var_kl", + "kl_ctrl_type": "fixed", + "kl_ctrl_coef": 0.001, + "horizon": 10000, + "target_kl": 0.1, + "actor_ppo_micro_batch_size_per_gpu": 4, + "ref_log_prob_micro_batch_size_per_gpu": 8, + "actor_ulysses_sequence_parallel_size": 1, + "actor_lr": 1e-6, + "actor_warmup_style": "constant", + "actor_lr_warmup_steps_ratio": 0.0, + "actor_tau": 0.0, + "actor_opmd_baseline": "mean", + "actor_use_uid": False, + "actor_grad_clip": 1.0, + "actor_clip_ratio": 0.2, + "actor_entropy_coeff": 0.001, + "actor_use_kl_loss": True, + "actor_kl_loss_coef": 0.001, + "actor_kl_loss_type": "low_var_kl", + "actor_checkpoint": ["model", "hf_model", "optimizer", "extra"], + "critic_lr": 1e-6, + "critic_warmup_style": "constant", + "critic_lr_warmup_steps_ratio": 0.0, + "critic_grad_clip": 1.0, + "critic_cliprange_value": 0.5, + "critic_ppo_micro_batch_size_per_gpu": 8, + "critic_ulysses_sequence_parallel_size": 1, + "training_mode": "PPO", + } - def set_value(self, key, value): - st.session_state[key] = value + def reset_session_state(self): + for key, value in self.default_config.items(): + st.session_state[key] = value - def beginer_mode(self): - st.write("Work in progress...") + def maintain_session_state(self): + for key in self.default_config: + st.session_state[key] = st.session_state[key] - def expert_mode(self): # noqa: 
C901 - model_tab, buffer_tab, connector_tab, trainer_tab = st.tabs( - ["Model", "Buffer", "Explorer and Synchronizer", "Trainer"] + def _set_project(self): + st.text_input("Project", key="project") + + def _set_name(self): + st.text_input("Experiment Name", key="exp_name") + + def _set_monitor_type(self): + st.selectbox( + "Monitor Type", + options=[monitor_type.value for monitor_type in MonitorType], + key="monitor_type", ) - with model_tab: - project_col, name_col = st.columns([1, 3]) - project = project_col.text_input("Project", "Trinity-RFT") - name = name_col.text_input("Experiment Name", "qwen2.5-1.5B") - - model_path = st.text_input("Model Path", "") - if not model_path.strip(): - self.unfinished_flag = True - st.warning("Please input model path") - critic_model_path = st.text_input("Critic Model Path (defaults to `model_path`)", "") - if not critic_model_path.strip(): - critic_model_path = model_path - - checkpoint_path = st.text_input("Checkpoint Path", "") - if not checkpoint_path.strip(): - self.unfinished_flag = True - st.warning("Please input checkpoint path") + def _set_model_path(self): + st.text_input("Model Path", key="model_path") + if not st.session_state["model_path"].strip(): + self.unfinished_fields.add("model_path") + st.warning("Please input model path.") + + def _set_critic_model_path(self): + st.text_input( + "Critic Model Path (defaults to `model_path`)", + key="critic_model_path", + ) + + def _set_checkpoint_path(self): + st.text_input("Checkpoint Path", key="checkpoint_path") + if not st.session_state["checkpoint_path"].strip(): # TODO: may auto generate + self.unfinished_fields.add("checkpoint_path") + st.warning("Please input checkpoint path.") + elif not os.path.isabs(st.session_state["checkpoint_path"].strip()): + self.unfinished_fields.add("checkpoint_path") + st.warning("Please input an absolute path.") + + def _set_node_num(self): + st.number_input("Node Num", key="node_num", min_value=1, on_change=self._set_total_gpu_num) + + def _set_gpu_per_node(self): + st.number_input( + "GPU Per Node", + key="gpu_per_node", + min_value=1, + max_value=8, + on_change=self._set_total_gpu_num, + ) + + def _set_total_gpu_num(self): + st.session_state["total_gpu_num"] = ( + st.session_state["gpu_per_node"] * st.session_state["node_num"] + ) + self._set_trainer_gpu_num() + + def _set_trainer_gpu_num(self): + st.session_state["trainer_gpu_num"] = ( + st.session_state["total_gpu_num"] + - st.session_state["engine_num"] * st.session_state["tensor_parallel_size"] + ) + + def _set_max_prompt_tokens(self): + st.number_input("Max Prompt Tokens", key="max_prompt_tokens", min_value=1) + + def _set_max_response_tokens(self): + st.number_input("Max Response Tokens", key="max_response_tokens", min_value=1) + + def _set_total_epochs(self): + st.number_input("Total Epochs", key="total_epochs", min_value=1) + + @property + def _str_for_task_num_per_batch(self): + return ( + f"Please ensure that `task_num_per_batch` can be divided by " + f"`gpu_per_node * node_num - engine_num * tensor_parallel_size` " + f"= {st.session_state['trainer_gpu_num']}" + ) + + def _set_task_num_per_batch(self): + trainer_gpu_num = st.session_state["trainer_gpu_num"] + if st.session_state["task_num_per_batch"] < trainer_gpu_num: + st.session_state["task_num_per_batch"] = trainer_gpu_num + st.number_input( + "Task Num Per Batch", + key="task_num_per_batch", + min_value=trainer_gpu_num, + step=trainer_gpu_num, + help=self._str_for_task_num_per_batch, + ) + + def _check_task_num_per_batch(self): + if 
st.session_state["task_num_per_batch"] % st.session_state["trainer_gpu_num"] != 0: + self.unfinished_fields.add("task_num_per_batch") + st.warning(self._str_for_task_num_per_batch) + + def _set_dataset_path(self): + st.text_input("Dataset Path", key="dataset_path") + if not st.session_state["dataset_path"].strip(): + self.unfinished_fields.add("dataset_path") + st.warning("Please input dataset path.") + + def _set_dataset_args(self): + if st.session_state["dataset_path"] and "://" not in st.session_state["dataset_path"]: + subset_name_col, train_split_col, eval_split_col = st.columns(3) + subset_name_col.text_input("Subset Name", key="subset_name") + train_split_col.text_input("Train Split", key="train_split") + eval_split_col.text_input("Eval Split", key="eval_split") + prompt_key_col, response_key_col = st.columns(2) + prompt_key_col.text_input("Prompt Key", key="prompt_key") + response_key_col.text_input("Response Key", key="response_key") + + def _set_default_workflow_type(self): + st.selectbox( + "Default Workflow Type", + WORKFLOWS.modules.keys(), + key="default_workflow_type", + help=r"""`simple_workflow`: call 'model.chat()' to get responses. + +`math_workflow`: call 'model.chat()' with a pre-defined system prompt to get responses. + +Other workflows: conduct multi-turn task for the given dataset. +""", + ) + + def _set_default_reward_fn_type(self): + st.selectbox( + "Default Reward Fn Type", + REWARD_FUNCTIONS.modules.keys(), + key="default_reward_fn_type", + help=r"""`accuracy_reward`: check the accuracy for math problems. + +`format_reward`: check if the response matches the format (default: `** *`). + +`math_reward`: `accuracy_reward` (1 or 0) + `format_reward` (+0.1 or -0.1). +""", + ) + + def _set_storage_type(self): + st.selectbox( + "Storage Type", + [storage_type.value for storage_type in StorageType], + key="storage_type", + ) + + def _set_db_url(self): + st.text_input( + "DB URL", + key="db_url", + help=r"Default to `sqlite:///{os.path.join(checkpoint_path, '.cache', project_name, experiment_name)}/data.db`", + ) + + def _set_max_retry_times(self): + st.number_input("Max Retry Times", key="max_retry_times", min_value=1) + + def _set_max_retry_interval(self): + st.number_input("Max Retry Interval", key="max_retry_interval", min_value=1) + + def _check_sft_warmup_dataset_path(self): + if st.session_state["sft_warmup_iteration"]: + if not st.session_state["sft_warmup_dataset_path"].strip(): + self.unfinished_fields.add("sft_warmup_dataset_path") + st.warning( + "Please input SFT warmup dataset path when `sft_warmup_iteration` is not 0" + ) + + def _set_sft_warmup_dataset_path(self): + st.text_input("SFT Warmup Dataset Path", key="sft_warmup_dataset_path") + self._check_sft_warmup_dataset_path() + + def _set_sft_warmup_dataset_args(self): + if ( + st.session_state["sft_warmup_dataset_path"] + and "://" not in st.session_state["sft_warmup_dataset_path"] + ): # TODO ( - node_num_col, - gpu_per_node_col, - max_prompt_tokens_col, - max_response_tokens_col, + sft_warmup_train_split_col, + sft_warmup_eval_split_col, + sft_warmup_prompt_key_col, + sft_warmup_response_key_col, ) = st.columns(4) - node_num = node_num_col.number_input("Node Num", value=1, min_value=1) - gpu_per_node = gpu_per_node_col.number_input( - "GPU Per Node", value=8, min_value=1, max_value=8 - ) - max_prompt_tokens = max_prompt_tokens_col.number_input( - "Max Prompt Tokens", value=256, min_value=1 - ) - max_response_tokens = max_response_tokens_col.number_input( - "Max Response Tokens", value=1024, min_value=1 
+ sft_warmup_train_split_col.text_input("SFT Train Split", key="sft_warmup_train_split") + sft_warmup_eval_split_col.text_input("SFT Eval Split", key="sft_warmup_eval_split") + sft_warmup_prompt_key_col.text_input("SFT Prompt Key", key="sft_warmup_prompt_key") + sft_warmup_response_key_col.text_input( + "SFT Response Key", key="sft_warmup_response_key" ) - with buffer_tab: - total_epoch_col, batch_size_per_gpu_col = st.columns(2) - total_epoch = total_epoch_col.number_input("Total Epoch", value=20, min_value=1) - batch_size_per_gpu = batch_size_per_gpu_col.number_input( - "Batch Size Per GPU", value=1, min_value=1 - ) + def _set_engine_type(self): + st.selectbox("Explorer Engine Type", ["vllm_async", "vllm"], key="engine_type") - dataset_path = st.text_input("Dataset Path", "") - if not dataset_path.strip(): - self.unfinished_flag = True - st.warning("Please input dataset path") - - if dataset_path and "://" not in dataset_path: - train_split_col, eval_split_col, prompt_key_col, response_key_col = st.columns(4) - train_split = train_split_col.text_input("Train Split", "train") - eval_split = eval_split_col.text_input("Eval Split", "") - prompt_key = prompt_key_col.text_input("Prompt Key", "question") - response_key = response_key_col.text_input("Response Key", "answer") - - default_workflow_type_col, default_reward_fn_type_col, storage_type_col = st.columns(3) - default_workflow_type = default_workflow_type_col.selectbox( - "Default Workflow Type", WORKFLOWS.modules.keys(), index=1 - ) - default_reward_fn_type = default_reward_fn_type_col.selectbox( - "Default Reward Fn Type", REWARD_FUNCTIONS.modules.keys(), index=3 - ) - storage_type = storage_type_col.selectbox( - "Storage Type", ["sql", "redis", "queue"], index=2 - ) + @property + def _str_for_engine_num_and_tp_size(self): + return r"""and it must meet the following constraints: +```python +assert engine_num * tensor_parallel_size < gpu_per_node * node_num +if node_num > 1: + assert gpu_per_node % tensor_parallel_size == 0 + assert engine_num * tensor_parallel_size % gpu_per_node == 0 +```""" - buffer_advanced_tab = st.expander("Advanced Config") - with buffer_advanced_tab: - db_url = st.text_input( - "DB URL", - "", - help=r"Default to `sqlite:///{os.path.join(checkpoint_path, '.cache', project, name)}/data.db`", - ) - if not db_url.strip(): - db_url = rf"sqlite:///{os.path.join(checkpoint_path, '.cache', project, name)}/data.db" + def _set_engine_num(self): + total_gpu_num = st.session_state["gpu_per_node"] * st.session_state["node_num"] + max_engine_num = (total_gpu_num - 1) // st.session_state["tensor_parallel_size"] + if st.session_state["engine_num"] > max_engine_num: + st.session_state["engine_num"] = max_engine_num + self._set_trainer_gpu_num() + st.number_input( + "Engine Num", + key="engine_num", + min_value=1, + max_value=max_engine_num, + help=f"`engine_num` is used to set the quantity of inference engines, " + f"{self._str_for_engine_num_and_tp_size}", + on_change=self._set_trainer_gpu_num, + ) - max_retry_times_col, max_retry_interval_col = st.columns(2) - max_retry_times = max_retry_times_col.number_input( - "Max Retry Times", value=3, min_value=1 + def _set_tensor_parallel_size(self): + total_gpu_num = st.session_state["gpu_per_node"] * st.session_state["node_num"] + max_tensor_parallel_size = (total_gpu_num - 1) // st.session_state["engine_num"] + if st.session_state["tensor_parallel_size"] > max_tensor_parallel_size: + st.session_state["tensor_parallel_size"] = max_tensor_parallel_size + self._set_trainer_gpu_num() + 
st.number_input( + "Tensor Parallel Size", + key="tensor_parallel_size", + min_value=1, + max_value=max_tensor_parallel_size, + help=f"`tensor_parallel_size` is used to set the tensor parallel size of inference engines, " + f"{self._str_for_engine_num_and_tp_size}", + on_change=self._set_trainer_gpu_num, + ) + + def _check_engine_num_and_tp_size(self): + node_num = st.session_state["node_num"] + gpu_per_node = st.session_state["gpu_per_node"] + engine_num = st.session_state["engine_num"] + tensor_parallel_size = st.session_state["tensor_parallel_size"] + if node_num > 1: + if gpu_per_node % tensor_parallel_size != 0: + self.unfinished_fields.add("tensor_parallel_size") + st.warning( + "Please ensure that `tensor_parallel_size` is a factor of `gpu_per_node` when `node_num > 1`." ) - max_retry_interval = max_retry_interval_col.number_input( - "Max Retry Interval", value=1, min_value=1 + if engine_num * tensor_parallel_size % gpu_per_node != 0: + self.unfinished_fields.add("engine_num") + st.warning( + "Please ensure that `engine_num * tensor_parallel_size` can be divided by `gpu_per_node` when `node_num > 1`." ) - sft_warmup_dataset_path = st.text_input("SFT Warmup Dataset Path", "") - if sft_warmup_dataset_path and "://" not in sft_warmup_dataset_path: # TODO - ( - sft_warmup_train_split_col, - sft_warmup_eval_split_col, - sft_warmup_prompt_key_col, - sft_warmup_response_key_col, - ) = st.columns(4) - sft_warmup_train_split = sft_warmup_train_split_col.text_input( # noqa: F841 - "SFT Train Split", "train" - ) - sft_warmup_eval_split = sft_warmup_eval_split_col.text_input( # noqa: F841 - "SFT Eval Split", "" - ) - sft_warmup_prompt_key = sft_warmup_prompt_key_col.text_input( # noqa: F841 - "SFT Prompt Key", "question" - ) - sft_warmup_response_key = sft_warmup_response_key_col.text_input( # noqa: F841 - "SFT Response Key", "answer" - ) - else: - sft_warmup_train_split = "" # noqa: F841 - sft_warmup_eval_split = "" # noqa: F841 - sft_warmup_prompt_key = "" # noqa: F841 - sft_warmup_response_key = "" # noqa: F841 + def _set_repeat_times(self): + if st.session_state["algorithm_type"] == AlgorithmType.OPMD.value or st.session_state[ + "adv_estimator" + ] in [ + AdvantageEstimator.GRPO.value, + AdvantageEstimator.RLOO.value, + ]: + min_repeat_times = 2 + else: + min_repeat_times = 1 + if st.session_state["repeat_times"] < min_repeat_times: + st.session_state["repeat_times"] = min_repeat_times + st.number_input( + "Repeat Times", + key="repeat_times", + min_value=min_repeat_times, + help="`repeat_times` is used to set how many experiences each task can generate, " + "and it must be greater than `1` when `algorithm_type` is `opmd` or `grpo`.", + ) - with connector_tab: - ( - engine_type_col, - engine_num_col, - tensor_parallel_size_col, - repeat_times_col, - ) = st.columns(4) - engine_type = engine_type_col.selectbox( - "Explorer Engine Type", ["vllm_async", "vllm"], index=0 - ) - if "engine_num" not in st.session_state: - st.session_state.engine_num = 2 - old_engine_num = min(st.session_state.engine_num, gpu_per_node * node_num) - engine_num = engine_num_col.number_input( - "Engine Num", - value=old_engine_num, - min_value=1, - max_value=gpu_per_node * node_num, - help="cannot exceed `gpu_per_node` * `node_num`", - ) - st.session_state.engine_num = engine_num - tensor_parallel_size = tensor_parallel_size_col.number_input( - "Tensor Parallel Size", value=1, min_value=1, max_value=8 + def _set_sync_method(self): + st.selectbox( + "Sync Method", + ["online", "offline"], + key="sync_method", + 
help="""`online`: the explorer and trainer sync model weights once every `sync_iteration_interval` steps. + +`offline`: the trainer saves the model checkpoint, and the explorer loads it at `sync_iteration_interval`.""", + ) + + def _set_sync_iteration_interval(self): + st.number_input( + "Sync Iteration Interval", + key="sync_iteration_interval", + min_value=1, + help="""The iteration interval at which the `explorer` and `trainer` synchronize model weight.""", + ) + + def _set_runner_num(self): + st.number_input("Runner Num", key="runner_num", min_value=1) + + def _set_max_pending_requests(self): + st.number_input("Max Pending Requests", key="max_pending_requests", min_value=1) + + def _set_max_waiting_steps(self): + st.number_input("Max Waiting Steps", key="max_waiting_steps", min_value=1) + + def _set_dtype(self): + st.selectbox("Dtype", ["float16", "bfloat16", "float32"], key="dtype") + + def _set_backend(self): + st.selectbox("Backend", ["nccl"], key="backend") + + def _set_temperature(self): + st.number_input("Temperature", key="temperature", min_value=0.0, max_value=2.0) + + def _set_top_p(self): + st.number_input("Top-p", key="top_p", min_value=0.0, max_value=1.0) + + def _set_top_k(self): + st.number_input( + "Top-k", + key="top_k", + min_value=-1, + max_value=512, + help="Integer that controls the number of top tokens to consider. Set to -1 to consider all tokens.", + ) + + def _set_seed(self): + st.number_input("Seed", key="seed", step=1) + + def _set_logprobs(self): + st.number_input("Logprobs", key="logprobs", min_value=0, max_value=20) + + def _set_enable_prefix_caching(self): + st.checkbox("Enable Prefix Caching", key="enable_prefix_caching") + + def _set_enforce_eager(self): + st.checkbox("Enforce Eager", key="enforce_eager") + + def _set_trainer_type(self): + st.selectbox("Trainer Type", ["verl"], key="trainer_type") + + def _set_algorithm_type(self): + st.selectbox( + "Algorithm Type", + [AlgorithmType.PPO.value, AlgorithmType.DPO.value, AlgorithmType.OPMD.value], + key="algorithm_type", + ) + + def _set_sft_warmup_iteration(self): + st.number_input("SFT Warmup Iteration", key="sft_warmup_iteration", min_value=0) + + def _set_eval_interval(self): + st.number_input("Eval Interval", key="eval_interval", min_value=1) + + def _set_training_args(self): + st.multiselect( + "Training Args", + [ + "balance_batch", + "gradient_checkpointing", + "remove_padding", + "dynamic_bsz", + ], + key="training_args", + ) + + def _set_save_freq(self): + if st.session_state["sync_method"] == "online": + freeze_save_freq = False + else: + st.session_state["save_freq"] = st.session_state["sync_iteration_interval"] + freeze_save_freq = True + st.number_input( + "Save Freq", + key="save_freq", + min_value=1, + help="Set to `sync_iteration_interval` when `sync_method` is `offline`", + disabled=freeze_save_freq, + ) + + def _set_training_strategy(self): + st.selectbox( + "Training Strategy", + ["fsdp", "megatron"], + key="training_strategy", + help="megatron is not tested", + ) + + def _set_param_offload(self): + st.checkbox("FSDP Param Offload", key="param_offload") + + def _set_optimizer_offload(self): + st.checkbox("FSDP Optimizer Offload", key="optimizer_offload") + + def _set_resume_mode(self): + st.selectbox("Resume Mode", ["disable", "auto", "resume_path"], key="resume_mode") + + def _set_resume_from_path(self): + if st.session_state["resume_mode"] == "resume_path": + st.text_input("Resume Path", key="resume_from_path") + if ( + not st.session_state["resume_from_path"].strip() + or 
"global_step_" not in st.session_state["resume_from_path"] + ): + self.unfinished_fields.add("resume_from_path") + st.warning("Please input a valid resume path when `resume_mode == resume_path`") + + def _set_critic_warmup(self): + st.number_input("Critic Warmup Iteration", key="critic_warmup", min_value=0) + + def _set_total_training_steps(self): + st.number_input("Total Training Steps", key="total_training_steps", min_value=1) + + def _set_default_hdfs_dir(self): + st.text_input("Default HDFS Dir", key="default_hdfs_dir") + + def _set_remove_previous_ckpt_in_save(self): + st.checkbox("Remove Previous Checkpoint in Save", key="remove_previous_ckpt_in_save") + + def _set_del_local_ckpt_after_load(self): + st.checkbox("Delete Local Checkpoint After Load", key="del_local_ckpt_after_load") + + def _set_max_actor_ckpt_to_keep(self): + st.number_input("Max Actor Checkpoint to Keep", key="max_actor_ckpt_to_keep", min_value=1) + + def _set_max_critic_ckpt_to_keep(self): + st.number_input("Max Critic Checkpoint to Keep", key="max_critic_ckpt_to_keep", min_value=1) + + def _set_gamma(self): + st.number_input("Gamma", key="gamma") + + def _set_lam(self): + st.number_input("Lambda", key="lam") + + def _set_adv_estimator(self): + st.selectbox( + "Advantage Estimator", + [member.value for member in AdvantageEstimator], + key="adv_estimator", + ) + + def _set_norm_adv_by_std_in_grpo(self): + st.checkbox("Norm Adv by Std in GRPO", key="norm_adv_by_std_in_grpo") + + def _set_use_kl_in_reward(self): + st.checkbox("Use KL in Reward", key="use_kl_in_reward") + + def _set_kl_penalty(self): + st.selectbox("KL Penalty", ["kl", "abs", "mse", "low_var_kl"], key="kl_penalty") + + def _set_kl_ctrl_type(self): + st.selectbox("KL Ctrl Type", ["fixed", "adaptive"], key="kl_ctrl_type") + + def _set_kl_ctrl_coef(self): + st.number_input("KL Ctrl Coef", key="kl_ctrl_coef", format="%.1e") + + def _set_horizon(self): + st.number_input("Horizon", key="horizon", min_value=1.0) + + def _set_target_kl(self): + st.number_input("Target KL", key="target_kl", format="%.1e") + + def _set_actor_ppo_micro_batch_size_per_gpu(self): + st.number_input( + "Micro Batch Size Per GPU for Actor", + key="actor_ppo_micro_batch_size_per_gpu", + min_value=1, + ) + + def _set_ref_log_prob_micro_batch_size_per_gpu(self): + st.number_input( + "Micro Batch Size Per GPU for Ref", + key="ref_log_prob_micro_batch_size_per_gpu", + min_value=1, + ) + + def _set_actor_ulysses_sequence_parallel_size(self): + st.number_input( + "Ulysses Sequence Parallel Size", + key="actor_ulysses_sequence_parallel_size", + min_value=1, + max_value=8, + ) + + def _set_actor_lr(self): + st.number_input( + "Learning Rate for Actor", + key="actor_lr", + min_value=1e-7, + max_value=1e-3, + format="%.1e", + ) + + def _set_actor_warmup_style(self): + st.selectbox( + "LR Warmup Style for Actor", + ["constant", "cosine"], + key="actor_warmup_style", + ) + + def _set_actor_lr_warmup_steps_ratio(self): + st.number_input( + "LR Warmup Steps Ratio for Actor", + key="actor_lr_warmup_steps_ratio", + min_value=0.0, + max_value=1.0, + ) + + def _set_actor_grad_clip(self): + st.number_input("Grad Clip", key="actor_grad_clip", min_value=0.0, max_value=1.0) + + def _set_actor_clip_ratio(self): + st.number_input("Clip Ratio", key="actor_clip_ratio", min_value=0.0, max_value=1.0) + + def _set_actor_entropy_coeff(self): + st.number_input( + "Entropy Coeff", + key="actor_entropy_coeff", + min_value=0.0, + max_value=1.0, + format="%.1e", + ) + + def _set_actor_use_kl_loss(self): + 
st.checkbox("Use KL Loss", key="actor_use_kl_loss") + + def _set_actor_kl_loss_coef(self): + st.number_input( + "KL Loss Coef", + key="actor_kl_loss_coef", + min_value=0.0, + max_value=1.0, + format="%.1e", + ) + + def _set_actor_kl_loss_type(self): + st.selectbox( + "KL Loss Type", + ["kl", "abs", "mse", "low_var_kl"], + key="actor_kl_loss_type", + ) + + def _set_actor_tau(self): + st.number_input( + "Tau for OPMD", + key="actor_tau", + min_value=0.0, + format="%.1e", + ) + + def _set_actor_opmd_baseline(self): + st.selectbox( + "OPMD Baseline", + ["mean", "logavgexp"], + key="actor_opmd_baseline", + ) + + def _set_actor_use_uid(self): + st.checkbox("Use UID for OPMD", key="actor_use_uid") + + def _set_actor_checkpoint(self): + st.multiselect( + "Checkpoint", + ["model", "hf_model", "optimizer", "extra"], + key="actor_checkpoint", + ) + + def _set_critic_ppo_micro_batch_size_per_gpu(self): + st.number_input( + "Micro Batch Size Per GPU for Critic", + key="critic_ppo_micro_batch_size_per_gpu", + min_value=1, + ) + + def _set_critic_ulysses_sequence_parallel_size(self): + st.number_input( + "Ulysses Sequence Parallel Size", + key="critic_ulysses_sequence_parallel_size", + min_value=1, + max_value=8, + ) + + def _set_critic_lr(self): + st.number_input( + "Learning Rate for Critic", + key="critic_lr", + min_value=1e-7, + max_value=1e-3, + format="%.1e", + ) + + def _set_critic_warmup_style(self): + st.selectbox( + "LR Warmup Style for Critic", + ["constant", "cosine"], + key="critic_warmup_style", + ) + + def _set_critic_lr_warmup_steps_ratio(self): + st.number_input( + "LR Warmup Steps Ratio for Critic", + key="critic_lr_warmup_steps_ratio", + min_value=0.0, + max_value=1.0, + ) + + def _set_critic_grad_clip(self): + st.number_input( + "Grad Clip for Critic", + key="critic_grad_clip", + min_value=0.0, + max_value=1.0, + ) + + def _set_critic_cliprange_value(self): + st.number_input( + "Cliprange Value", + key="critic_cliprange_value", + min_value=0.0, + max_value=1.0, + ) + + def _set_training_mode(self): + st.selectbox("Training Mode", ["PPO", "GRPO", "DPO", "OPMD"], key="training_mode") + + if st.session_state["training_mode"] == "PPO": + st.session_state["algorithm_type"] = AlgorithmType.PPO.value + st.session_state["adv_estimator"] = "gae" + elif st.session_state["training_mode"] == "GRPO": + st.session_state["algorithm_type"] = AlgorithmType.PPO.value + st.session_state["adv_estimator"] = "grpo" + elif st.session_state["training_mode"] == "DPO": + st.session_state["algorithm_type"] = AlgorithmType.DPO.value + st.session_state["adv_estimator"] = "grpo" + elif st.session_state["training_mode"] == "OPMD": + st.session_state["algorithm_type"] = AlgorithmType.OPMD.value + st.session_state["adv_estimator"] = "grpo" + + def _set_configs_with_st_columns( + self, config_names: List[str], columns_config: List[int] = None + ): + if columns_config is None: + columns_config = len(config_names) + columns = st.columns(columns_config) + for col, config_name in zip(columns, config_names): + with col: + getattr(self, f"_set_{config_name}")() + + def beginner_mode(self): + st.header("Essential Configs") + self._set_configs_with_st_columns(["project", "name"], columns_config=[1, 3]) + + self._set_model_path() + + self._set_checkpoint_path() + + self._set_dataset_path() + + self._set_configs_with_st_columns(["training_mode", "sft_warmup_iteration", "monitor_type"]) + if st.session_state["sft_warmup_iteration"] > 0: + self._set_sft_warmup_dataset_path() + + st.header("Important Configs") + 
self._set_configs_with_st_columns( + ["node_num", "gpu_per_node", "engine_num", "tensor_parallel_size"] + ) + self._check_engine_num_and_tp_size() + + self._set_configs_with_st_columns( + ["total_epochs", "task_num_per_batch", "max_prompt_tokens", "max_response_tokens"] + ) + self._check_task_num_per_batch() + + self._set_dataset_args() + + if st.session_state["sft_warmup_iteration"] > 0: + self._set_sft_warmup_dataset_args() + + self._set_configs_with_st_columns( + ["default_workflow_type", "default_reward_fn_type", "repeat_times"] + ) + + self._set_configs_with_st_columns(["sync_iteration_interval", "eval_interval", "save_freq"]) + + self._set_actor_use_kl_loss() + if st.session_state["actor_use_kl_loss"]: + self._set_configs_with_st_columns(["actor_kl_loss_coef", "actor_kl_loss_type"]) + + self._set_configs_with_st_columns( + [ + "actor_ppo_micro_batch_size_per_gpu", + "actor_lr", + "ref_log_prob_micro_batch_size_per_gpu", + ] + ) + + use_critic = st.session_state["adv_estimator"] == "gae" # TODO: may apply to expert mode + if use_critic: + self._set_configs_with_st_columns(["critic_ppo_micro_batch_size_per_gpu", "critic_lr"]) + + def _expert_model_part(self): + self._set_configs_with_st_columns(["project", "name"], columns_config=[1, 3]) + + self._set_model_path() + self._set_critic_model_path() + + self._set_checkpoint_path() + + self._set_configs_with_st_columns(["monitor_type", "node_num", "gpu_per_node"]) + self._set_configs_with_st_columns(["max_prompt_tokens", "max_response_tokens"]) + + def _expert_buffer_part(self): + self._set_configs_with_st_columns(["total_epochs", "task_num_per_batch"]) + self._check_task_num_per_batch() + + self._set_dataset_path() + + self._set_dataset_args() + + self._set_configs_with_st_columns( + ["default_workflow_type", "default_reward_fn_type", "storage_type"] + ) + + self.buffer_advanced_tab = st.expander("Advanced Config") + with self.buffer_advanced_tab: + self._set_db_url() + + self._set_configs_with_st_columns(["max_retry_times", "max_retry_interval"]) + + self._set_sft_warmup_dataset_path() + self._set_sft_warmup_dataset_args() + + def _expert_connector_part(self): + self._set_configs_with_st_columns( + ["engine_type", "engine_num", "tensor_parallel_size", "repeat_times"] + ) + self._check_engine_num_and_tp_size() + + self._set_configs_with_st_columns(["sync_method", "sync_iteration_interval"]) + + with st.expander("Advanced Config"): + self._set_configs_with_st_columns( + ["runner_num", "max_pending_requests", "max_waiting_steps", "dtype"] ) - repeat_times = repeat_times_col.number_input("Repeat Times", value=1, min_value=1) - sync_method_col, sync_iteration_interval_col = st.columns(2) - sync_method = sync_method_col.selectbox("Sync Method", ["online", "offline"], index=0) - sync_iteration_interval = sync_iteration_interval_col.number_input( - "Sync Iteration Interval", value=10, min_value=1 + self._set_configs_with_st_columns( + ["backend", "temperature", "top_p", "top_k", "seed", "logprobs"] ) + + self._set_configs_with_st_columns(["enable_prefix_caching", "enforce_eager"]) + + def _expert_trainer_part(self): + self._set_configs_with_st_columns( + ["trainer_type", "algorithm_type", "sft_warmup_iteration", "eval_interval"] + ) + self._check_sft_warmup_dataset_path() + + if st.session_state["trainer_type"] == "verl": + self._expert_verl_trainer_part() + + def _expert_verl_trainer_part(self): + rl_training_tab, rl_algorithm_tab, actor_ref_tab, critic_tab = st.tabs( + [ + "RL Training Config", + "RL Algorithm Config", + "Actor and Ref Config", 
+ "Critic Config", + ] + ) + with rl_training_tab: + st.subheader("RL Training Config") + self._set_training_args() + + self._set_configs_with_st_columns(["save_freq", "training_strategy", "resume_mode"]) + + if st.session_state["training_strategy"] == "fsdp": + self._set_configs_with_st_columns(["param_offload", "optimizer_offload"]) + self._set_resume_from_path() + with st.expander("Advanced Config"): - ( - runner_num_col, - max_pending_requests_col, - max_waiting_steps_col, - dtype_col, - ) = st.columns(4) - runner_num = runner_num_col.number_input("Runner Num", value=32, min_value=1) - max_pending_requests = max_pending_requests_col.number_input( - "Max Pending Requests", value=32, min_value=1 - ) - max_waiting_steps = max_waiting_steps_col.number_input( - "Max Waiting Steps", value=4, min_value=1 - ) - dtype = dtype_col.selectbox("Dtype", ["float16", "bfloat16", "float32"], index=1) - - ( - backend_col, - temperature_col, - top_p_col, - top_k_col, - seed_col, - logprobs_col, - ) = st.columns(6) - backend = backend_col.selectbox("Backend", ["nccl"], index=0) - temperature = temperature_col.number_input( - "Temperature", value=1.0, min_value=0.0, max_value=2.0 + self._set_configs_with_st_columns(["critic_warmup", "total_training_steps"]) + + self._set_default_hdfs_dir() + + self._set_configs_with_st_columns( + ["remove_previous_ckpt_in_save", "del_local_ckpt_after_load"] ) - top_p = top_p_col.number_input("Top P", value=1.0, min_value=0.0, max_value=1.0) - top_k = top_k_col.number_input("Top K", value=1, min_value=1, max_value=512) - seed = seed_col.number_input("Seed", value=42) - logprobs = logprobs_col.number_input("Logprobs", value=0, min_value=0, max_value=20) - - enable_prefix_caching_col, enforce_eager_col = st.columns(2) - enable_prefix_caching = enable_prefix_caching_col.checkbox( - "Enable Prefix Caching", value=False + + self._set_configs_with_st_columns( + ["max_actor_ckpt_to_keep", "max_critic_ckpt_to_keep"] ) - enforce_eager = enforce_eager_col.checkbox("Enforce Eager", value=True) - gpu_num = gpu_per_node * node_num - engine_num + with rl_algorithm_tab: + st.subheader("RL Algorithm Config") + self._set_configs_with_st_columns(["gamma", "lam", "adv_estimator"]) + self._set_configs_with_st_columns(["norm_adv_by_std_in_grpo", "use_kl_in_reward"]) + self._set_configs_with_st_columns(["kl_penalty", "kl_ctrl_type", "kl_ctrl_coef"]) + self._set_configs_with_st_columns(["horizon", "target_kl"]) - with trainer_tab: - trainer_type_col, sft_warmup_iteration_col, eval_interval_col = st.columns(3) - trainer_type = trainer_type_col.selectbox("Trainer Type", ["verl"], index=0) - sft_warmup_iteration = sft_warmup_iteration_col.number_input( - "SFT Warmup Iteration", value=0, min_value=0 + with actor_ref_tab: + st.subheader("Actor Model Config") + self._set_configs_with_st_columns( + [ + "actor_ppo_micro_batch_size_per_gpu", + "ref_log_prob_micro_batch_size_per_gpu", + "actor_ulysses_sequence_parallel_size", + ] ) - if sft_warmup_iteration and not sft_warmup_dataset_path.strip(): - self.unfinished_flag = True - st.warning( - "Please input SFT warmup dataset path when `sft_warmup_iteration` is not 0" - ) - with buffer_advanced_tab: - st.warning( - "Please input SFT warmup dataset path when `sft_warmup_iteration` is not 0" - ) - eval_interval = eval_interval_col.number_input("Eval Interval", value=1000, min_value=1) - if trainer_type == "verl": - trainer_config_path = st.text_input("Trainer Config Path", "") - if not trainer_config_path.strip(): - self.unfinished_flag = True - 
st.warning("Please input trainer config path") - - rl_training_tab, rl_algorithm_tab, actor_ref_tab, critic_tab = st.tabs( - [ - "RL Training Config", - "RL Algorithm Config", - "Actor and Ref Config", - "Critic Config", - ] + + self._set_configs_with_st_columns( + ["actor_lr", "actor_warmup_style", "actor_lr_warmup_steps_ratio"] + ) + + self._set_configs_with_st_columns( + ["actor_grad_clip", "actor_clip_ratio", "actor_entropy_coeff"] + ) + + self._set_actor_use_kl_loss() + if st.session_state["actor_use_kl_loss"]: + self._set_configs_with_st_columns(["actor_kl_loss_coef", "actor_kl_loss_type"]) + + if st.session_state["algorithm_type"] == "opmd": + self._set_configs_with_st_columns( + ["actor_tau", "actor_opmd_baseline", "actor_use_uid"] ) - with rl_training_tab: - st.subheader("RL Training Config") - training_args = st.multiselect( - "Training Args", - [ - "balance_batch", - "gradient_checkpointing", - "remove_padding", - "dynamic_bsz", - ], - default=[ - "balance_batch", - "gradient_checkpointing", - "remove_padding", - "dynamic_bsz", - ], - ) - balance_batch = "balance_batch" in training_args - enable_gradient_checkpointing = "gradient_checkpointing" in training_args - use_remove_padding = "remove_padding" in training_args - use_dynamic_bsz = "dynamic_bsz" in training_args - - ( - save_freq_col, - training_strategy_col, - resume_mode_col, - ) = st.columns(3) - if "save_freq" not in st.session_state: - st.session_state.save_freq = 100 - if sync_method == "online": - save_freq = save_freq_col.number_input( - "Save Freq", - value=st.session_state.save_freq, - min_value=1, - help="Set to `sync_iteration_interval` when `sync_method` is `offline`", - ) - st.session_state.save_freq = save_freq - else: - st.session_state.save_freq = sync_iteration_interval - save_freq = save_freq_col.number_input( - "Save Freq", - value=st.session_state.save_freq, - min_value=1, - help="Set to `sync_iteration_interval` when `sync_method` is `offline`", - disabled=True, - ) - - training_strategy = training_strategy_col.selectbox( - "Training Strategy", - ["fsdp", "megatron"], - index=0, - help="megatron is not tested", - ) - if training_strategy == "fsdp": - param_offload_col, optimizer_offload_col = st.columns(2) - param_offload = param_offload_col.checkbox( - "FSDP Param Offload", value=False - ) - optimizer_offload = optimizer_offload_col.checkbox( - "FSDP Optimizer Offload", value=False - ) - fsdp_config = { - "wrap_policy": {"min_num_params": 0}, - "param_offload": param_offload, - "optimizer_offload": optimizer_offload, - "fsdp_size": -1, - } - else: - fsdp_config = {} - - resume_mode = resume_mode_col.selectbox( - "Resume Mode", ["disable", "auto", "resume_path"], index=1 - ) - if "resume_from_path" not in st.session_state: - st.session_state.resume_from_path = "" - if resume_mode == "resume_path": - resume_from_path = st.text_input( - "Resume Path", st.session_state.resume_from_path - ) - st.session_state.resume_from_path = resume_from_path - if not resume_from_path.strip() or "global_step_" not in resume_from_path: - self.unfinished_flag = True - st.warning( - "Please input a valid resume path when `resume_mode` is `resume_path`" - ) - else: - resume_from_path = st.session_state.resume_from_path - - with st.expander("Advanced Config"): - critic_warmup_col, total_training_steps_col = st.columns(2) - critic_warmup = critic_warmup_col.number_input( - "Critic Warmup Iteration", value=0, min_value=0 - ) - total_training_steps = total_training_steps_col.number_input( - "Total Training Steps", value=None, 
min_value=1 - ) - - default_hdfs_dir = st.text_input("Default HDFS Dir", None) - - ( - remove_previous_ckpt_in_save_col, - del_local_ckpt_after_load_col, - ) = st.columns(2) - remove_previous_ckpt_in_save = remove_previous_ckpt_in_save_col.checkbox( - "Remove Previous Checkpoint in Save", value=False - ) - del_local_ckpt_after_load = del_local_ckpt_after_load_col.checkbox( - "Delete Local Checkpoint After Load", value=False - ) - - max_actor_ckpt_to_keep_col, max_critic_ckpt_to_keep_col = st.columns(2) - max_actor_ckpt_to_keep = max_actor_ckpt_to_keep_col.number_input( - "Max Actor Checkpoint to Keep", value=None, min_value=1 - ) - max_critic_ckpt_to_keep = max_critic_ckpt_to_keep_col.number_input( - "Max Critic Checkpoint to Keep", value=None, min_value=1 - ) - - with rl_algorithm_tab: - st.subheader("RL Algorithm Config") - gamma_col, lam_col, adv_estimator_col = st.columns(3) - gamma = gamma_col.number_input("Gamma", value=1.0) - lam = lam_col.number_input("lam", value=1.0) - adv_estimator = adv_estimator_col.selectbox( - "Advantage Estimator", - [member.value for member in AdvantageEstimator], - index=0, - ) - kl_penalty_col, kl_ctrl_type_col, kl_ctrl_coef_col = st.columns(3) - kl_penalty = kl_penalty_col.selectbox( - "KL Penalty", ["kl", "abs", "mse", "low_var_kl"], index=0 - ) - kl_ctrl_type = kl_ctrl_type_col.selectbox( - "KL Ctrl Type", ["fixed", "adaptive"], index=0 - ) - kl_ctrl_coef = kl_ctrl_coef_col.number_input("KL Ctrl Coef", value=0.001) - - with actor_ref_tab: - st.subheader("Actor Model Config") - ( - actor_ppo_micro_batch_size_per_gpu_col, - ref_log_prob_micro_batch_size_per_gpu_col, - actor_ulysses_sequence_parallel_size_col, - ) = st.columns(3) - actor_ppo_micro_batch_size_per_gpu = ( - actor_ppo_micro_batch_size_per_gpu_col.number_input( - "Micro Batch Size Per GPU for Actor", value=4, min_value=1 - ) - ) - ref_log_prob_micro_batch_size_per_gpu = ( - ref_log_prob_micro_batch_size_per_gpu_col.number_input( - "Micro Batch Size Per GPU for Ref", value=8, min_value=1 - ) - ) - actor_ulysses_sequence_parallel_size = ( - actor_ulysses_sequence_parallel_size_col.number_input( - "Ulysses Sequence Parallel Size", value=1, min_value=1, max_value=8 - ) - ) - - ( - actor_lr_col, - actor_warmup_style_col, - actor_lr_warmup_steps_ratio_col, - ) = st.columns(3) - actor_lr = actor_lr_col.number_input( - "Learning Rate for actor", - value=1e-6, - min_value=1e-7, - max_value=1e-3, - format="%.1e", - ) - actor_warmup_style = actor_warmup_style_col.selectbox( - "LR Warmup Style", ["constant", "cosine"], index=0 - ) - actor_lr_warmup_steps_ratio = actor_lr_warmup_steps_ratio_col.number_input( - "LR Warmup Steps Ratio", value=0.0, min_value=0.0, max_value=1.0 - ) - - ( - actor_alg_type_col, - actor_grad_clip_col, - actor_clip_ratio_col, - actor_entropy_coeff_col, - ) = st.columns(4) - actor_alg_type = actor_alg_type_col.selectbox( - "Algorithm Type", ["ppo", "opmd", "pairwise_opmd"], index=0 - ) - if "actor_tau" not in st.session_state: - st.session_state.actor_tau = 0.0 - st.session_state.actor_opmd_baseline = "mean" - st.session_state.actor_use_uid = False - if actor_alg_type != "ppo": - actor_tau_col, actor_opmd_baseline_col, actor_use_uid_col = st.columns(3) - actor_tau = actor_tau_col.number_input( - "Tau for OPMD", - value=0.0, - min_value=0.0, - max_value=1.0, - format="%.1e", - ) - actor_opmd_baseline = actor_opmd_baseline_col.selectbox( - "OPMD Baseline", - ["mean", "logavgexp"], - index=0, - ) - actor_use_uid = actor_use_uid_col.checkbox("Use UID for OPMD", value=False) - 
st.session_state.actor_tau = actor_tau - st.session_state.actor_opmd_baseline = actor_opmd_baseline - st.session_state.actor_use_uid = actor_use_uid - else: - actor_tau = st.session_state.actor_tau - actor_opmd_baseline = st.session_state.actor_opmd_baseline - actor_use_uid = st.session_state.actor_use_uid - - actor_grad_clip = actor_grad_clip_col.number_input( - "Grad Clip", value=1.0, min_value=0.0, max_value=1.0 - ) - actor_clip_ratio = actor_clip_ratio_col.number_input( - "Clip Ratio", value=0.2, min_value=0.0, max_value=1.0 - ) - actor_entropy_coeff = actor_entropy_coeff_col.number_input( - "Entropy Coeff", value=0.001, min_value=0.0, max_value=1.0 - ) - - actor_use_kl_loss = st.checkbox("Use KL Loss (True for GRPO)", value=False) - if "actor_kl_loss_coef" not in st.session_state: - st.session_state.actor_kl_loss_coef = 0.001 - st.session_state.actor_kl_loss_type = "low_var_kl" - if actor_use_kl_loss: - actor_kl_loss_coef_col, actor_kl_loss_type_col = st.columns(2) - actor_kl_loss_coef = actor_kl_loss_coef_col.number_input( - "KL Loss Coef", - value=st.session_state.actor_kl_loss_coef, - min_value=0.0, - max_value=1.0, - format="%.1e", - ) - actor_kl_loss_type_candidates = ["kl", "abs", "mse", "low_var_kl"] - actor_kl_loss_type = actor_kl_loss_type_col.selectbox( - "KL Loss Type", - actor_kl_loss_type_candidates, - index=actor_kl_loss_type_candidates.index( - st.session_state.actor_kl_loss_type - ), - ) - st.session_state.actor_kl_loss_coef = actor_kl_loss_coef - st.session_state.actor_kl_loss_type = actor_kl_loss_type - else: - actor_kl_loss_coef = st.session_state.actor_kl_loss_coef - actor_kl_loss_type = st.session_state.actor_kl_loss_type - - actor_checkpoint = st.multiselect( - "Checkpoint", - ["model", "hf_model", "optimizer", "extra"], - default=["model", "hf_model", "optimizer", "extra"], - ) - - with critic_tab: - st.subheader("Critic Model Config") - ( - critic_lr_col, - critic_warmup_style_col, - critic_lr_warmup_steps_ratio_col, - critic_grad_clip_col, - ) = st.columns(4) - critic_lr = critic_lr_col.number_input( - "Learning Rate", - key="Learning Rate for Critic", - value=1e-6, - min_value=1e-7, - max_value=1e-3, - format="%.1e", - ) - critic_warmup_style = critic_warmup_style_col.selectbox( - "LR Warmup Style", - ["constant", "cosine"], - key="LR Warmup Style for Critic", - index=0, - ) - critic_lr_warmup_steps_ratio = critic_lr_warmup_steps_ratio_col.number_input( - "LR Warmup Steps Ratio", - key="LR Warmup Steps Ratio for Critic", - value=0.0, - min_value=0.0, - max_value=1.0, - ) - critic_grad_clip = critic_grad_clip_col.number_input( - "Grad Clip", - key="Grad Clip for Critic", - value=1.0, - min_value=0.0, - max_value=1.0, - ) - - ( - critic_cliprange_value_col, - critic_ppo_micro_batch_size_per_gpu_col, - critic_ulysses_sequence_parallel_size_col, - ) = st.columns(3) - critic_cliprange_value = critic_cliprange_value_col.number_input( - "Cliprange Value", - key="Cliprange Value for Critic", - value=0.5, - min_value=0.0, - max_value=1.0, - ) - critic_ppo_micro_batch_size_per_gpu = ( - critic_ppo_micro_batch_size_per_gpu_col.number_input( - "Micro Batch Size Per GPU for Critic", value=8, min_value=1 - ) - ) - critic_ulysses_sequence_parallel_size = ( - critic_ulysses_sequence_parallel_size_col.number_input( - "Ulysses Sequence Parallel Size", - key="Ulysses Sequence Parallel Size for Critic", - value=1, - min_value=1, - max_value=8, - ) - ) - - rollout_node_num = engine_num * tensor_parallel_size // gpu_per_node - trainer_nnodes = node_num - rollout_node_num - if 
node_num == 1: - trainer_n_gpus_per_node = gpu_per_node - engine_num * tensor_parallel_size - else: - trainer_n_gpus_per_node = gpu_per_node - - if trainer_type == "verl": - trainer_config = { - "data": { - "tokenizer": None, - "train_files": "placeholder", - "val_files": "placeholder", - "prompt_key": "placeholder", - "max_prompt_length": max_prompt_tokens, - "max_response_length": max_response_tokens, - "train_batch_size": batch_size_per_gpu * gpu_num * repeat_times, - "val_batch_size": None, - "return_raw_input_ids": False, - "return_raw_chat": False, - "shuffle": True, - "filter_overlong_prompts": False, - "truncation": "error", - "image_key": "images", - }, - "actor_rollout_ref": { - "hybrid_engine": True, - "model": { - "path": model_path, - "external_lib": None, - "override_config": {}, - "enable_gradient_checkpointing": enable_gradient_checkpointing, - "use_remove_padding": use_remove_padding, - }, - "actor": { - "strategy": training_strategy, - "ppo_mini_batch_size": batch_size_per_gpu * gpu_num, - "ppo_micro_batch_size_per_gpu": actor_ppo_micro_batch_size_per_gpu, - "use_dynamic_bsz": use_dynamic_bsz, - "ppo_max_token_len_per_gpu": repeat_times - * (max_prompt_tokens + max_response_tokens), - "grad_clip": actor_grad_clip, - "clip_ratio": actor_clip_ratio, - "entropy_coeff": actor_entropy_coeff, - "use_kl_loss": actor_use_kl_loss, - "kl_loss_coef": actor_kl_loss_coef, - "kl_loss_type": actor_kl_loss_type, - "ppo_epochs": 1, # TODO - "shuffle": False, - "ulysses_sequence_parallel_size": actor_ulysses_sequence_parallel_size, - "checkpoint": {"contents": actor_checkpoint}, - "optim": { - "lr": actor_lr, - "lr_warmup_steps_ratio": actor_lr_warmup_steps_ratio, - "warmup_style": actor_warmup_style, - "total_training_steps": -1 - if total_training_steps is None - else total_training_steps, - }, - "fsdp_config": fsdp_config, - "alg_type": actor_alg_type, - "tau": actor_tau, - "opmd_baseline": actor_opmd_baseline, - "use_uid": actor_use_uid, - }, - "ref": { - "fsdp_config": fsdp_config, - "log_prob_micro_batch_size_per_gpu": ref_log_prob_micro_batch_size_per_gpu, - "log_prob_use_dynamic_bsz": "${actor_rollout_ref.actor.use_dynamic_bsz}", - "log_prob_max_token_len_per_gpu": "${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}", - "ulysses_sequence_parallel_size": "${actor_rollout_ref.actor.ulysses_sequence_parallel_size}", - }, - "rollout": { - "name": "vllm", - "temperature": temperature, - "top_k": -1, - "top_p": 1, - "use_fire_sampling": False, - "prompt_length": "${data.max_prompt_length}", - "response_length": "${data.max_response_length}", - "dtype": "bfloat16", - "gpu_memory_utilization": 0.4, - "ignore_eos": False, - "enforce_eager": True, - "free_cache_engine": True, - "load_format": "dummy_dtensor", - "tensor_model_parallel_size": 2, - "max_num_batched_tokens": 8192, - "max_model_len": None, - "max_num_seqs": 1024, - "log_prob_micro_batch_size_per_gpu": 4, - "log_prob_use_dynamic_bsz": "${actor_rollout_ref.actor.use_dynamic_bsz}", - "log_prob_max_token_len_per_gpu": "${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}", - "disable_log_stats": True, - "enable_chunked_prefill": True, - "do_sample": True, - "n": repeat_times, - }, + + self._set_actor_checkpoint() + + with critic_tab: + st.subheader("Critic Model Config") + self._set_configs_with_st_columns( + ["critic_ppo_micro_batch_size_per_gpu", "critic_ulysses_sequence_parallel_size"] + ) + + self._set_configs_with_st_columns( + ["critic_lr", "critic_warmup_style", "critic_lr_warmup_steps_ratio"] + ) + + 
self._set_configs_with_st_columns(["critic_grad_clip", "critic_cliprange_value"]) + + def expert_mode(self): + model_tab, buffer_tab, connector_tab, trainer_tab = st.tabs( + ["Model", "Data", "Explorer and Synchronizer", "Trainer"] + ) + with model_tab: + self._expert_model_part() + + with buffer_tab: + self._expert_buffer_part() + + with connector_tab: + self._expert_connector_part() + + with trainer_tab: + self._expert_trainer_part() + + def _generate_verl_config(self, trainer_nnodes: int = 1, trainer_n_gpus_per_node: int = 8): + balance_batch = "balance_batch" in st.session_state["training_args"] + enable_gradient_checkpointing = ( + "gradient_checkpointing" in st.session_state["training_args"] + ) + use_remove_padding = "remove_padding" in st.session_state["training_args"] + use_dynamic_bsz = "dynamic_bsz" in st.session_state["training_args"] + + if st.session_state["training_strategy"] == "fsdp": + fsdp_config = { + "wrap_policy": {"min_num_params": 0}, + "param_offload": st.session_state["param_offload"], + "optimizer_offload": st.session_state["optimizer_offload"], + "fsdp_size": -1, + } + else: + fsdp_config = {} + + ppo_epochs = 1 # TODO + ppo_max_token_len_per_gpu = st.session_state["repeat_times"] * ( + st.session_state["max_prompt_tokens"] + st.session_state["max_response_tokens"] + ) + + critic_model_path = ( + st.session_state["critic_model_path"].strip() + if st.session_state["critic_model_path"].strip() + else st.session_state["model_path"] + ) + trainer_config = { + "data": { + "tokenizer": None, + "train_files": "placeholder", + "val_files": "placeholder", + "prompt_key": "placeholder", + "max_prompt_length": st.session_state["max_prompt_tokens"], + "max_response_length": st.session_state["max_response_tokens"], + "train_batch_size": st.session_state["task_num_per_batch"] + * st.session_state["repeat_times"], + "val_batch_size": None, + "return_raw_input_ids": False, + "return_raw_chat": False, + "shuffle": True, + "filter_overlong_prompts": False, + "truncation": "error", + "image_key": "images", + }, + "actor_rollout_ref": { + "hybrid_engine": True, + "model": { + "path": st.session_state["model_path"], + "external_lib": None, + "override_config": {}, + "enable_gradient_checkpointing": enable_gradient_checkpointing, + "use_remove_padding": use_remove_padding, }, - "critic": { - "strategy": training_strategy, + "actor": { + "strategy": st.session_state["training_strategy"], + "ppo_mini_batch_size": st.session_state["task_num_per_batch"], + "ppo_micro_batch_size_per_gpu": st.session_state[ + "actor_ppo_micro_batch_size_per_gpu" + ], + "use_dynamic_bsz": use_dynamic_bsz, + "ppo_max_token_len_per_gpu": ppo_max_token_len_per_gpu, + "grad_clip": st.session_state["actor_grad_clip"], + "clip_ratio": st.session_state["actor_clip_ratio"], + "entropy_coeff": st.session_state["actor_entropy_coeff"], + "use_kl_loss": st.session_state["actor_use_kl_loss"], + "kl_loss_coef": st.session_state["actor_kl_loss_coef"], + "kl_loss_type": st.session_state["actor_kl_loss_type"], + "ppo_epochs": ppo_epochs, + "shuffle": False, + "ulysses_sequence_parallel_size": st.session_state[ + "actor_ulysses_sequence_parallel_size" + ], + "checkpoint": {"contents": st.session_state["actor_checkpoint"]}, "optim": { - "lr": critic_lr, - "lr_warmup_steps_ratio": critic_warmup_style, - "warmup_style": critic_lr_warmup_steps_ratio, + "lr": st.session_state["actor_lr"], + "lr_warmup_steps_ratio": st.session_state["actor_lr_warmup_steps_ratio"], + "warmup_style": st.session_state["actor_warmup_style"], 
"total_training_steps": -1 - if total_training_steps is None - else total_training_steps, - }, - "model": { - "path": critic_model_path, - "tokenizer_path": "${actor_rollout_ref.model.path}", - "override_config": {}, - "external_lib": "${actor_rollout_ref.model.external_lib}", - "enable_gradient_checkpointing": enable_gradient_checkpointing, - "use_remove_padding": use_remove_padding, - "fsdp_config": fsdp_config, + if st.session_state["total_training_steps"] is None + else st.session_state["total_training_steps"], }, - "ppo_mini_batch_size": "${actor_rollout_ref.actor.ppo_mini_batch_size}", - "ppo_micro_batch_size_per_gpu": critic_ppo_micro_batch_size_per_gpu, - "forward_micro_batch_size_per_gpu": "${critic.ppo_micro_batch_size_per_gpu}", - "use_dynamic_bsz": use_dynamic_bsz, - "ppo_max_token_len_per_gpu": repeat_times - * (max_prompt_tokens + max_response_tokens) - * 2, - "forward_max_token_len_per_gpu": "${critic.ppo_max_token_len_per_gpu}", - "ulysses_sequence_parallel_size": critic_ulysses_sequence_parallel_size, - "ppo_epochs": "${actor_rollout_ref.actor.ppo_epochs}", - "shuffle": "${actor_rollout_ref.actor.shuffle}", - "grad_clip": critic_grad_clip, - "cliprange_value": critic_cliprange_value, + "fsdp_config": copy.deepcopy(fsdp_config), + "alg_type": st.session_state["algorithm_type"], + "tau": st.session_state["actor_tau"], + "opmd_baseline": st.session_state["actor_opmd_baseline"], + "use_uid": st.session_state["actor_use_uid"], }, - "reward_model": { - "enable": False, - "strategy": "fsdp", - "model": { - "input_tokenizer": "${actor_rollout_ref.model.path}", - "path": "~/models/FsfairX-LLaMA3-RM-v0.1", - "external_lib": "${actor_rollout_ref.model.external_lib}", - "use_remove_padding": False, - "fsdp_config": { - "min_num_params": 0, - "param_offload": False, - "fsdp_size": -1, - }, - }, - "ulysses_sequence_parallel_size": 1, - "use_dynamic_bsz": "${critic.use_dynamic_bsz}", - "forward_max_token_len_per_gpu": "${critic.forward_max_token_len_per_gpu}", - "reward_manager": "naive", + "ref": { + "fsdp_config": copy.deepcopy(fsdp_config), + "log_prob_micro_batch_size_per_gpu": st.session_state[ + "ref_log_prob_micro_batch_size_per_gpu" + ], + "log_prob_use_dynamic_bsz": use_dynamic_bsz, + "log_prob_max_token_len_per_gpu": ppo_max_token_len_per_gpu, + "ulysses_sequence_parallel_size": st.session_state[ + "actor_ulysses_sequence_parallel_size" + ], }, - "custom_reward_function": {"path": None, "name": "compute_score"}, - "algorithm": { - "gamma": gamma, - "lam": lam, - "adv_estimator": adv_estimator, - "kl_penalty": kl_penalty, - "kl_ctrl": {"type": kl_ctrl_type, "kl_coef": kl_ctrl_coef}, + "rollout": { + "name": "vllm", + "temperature": st.session_state["temperature"], + "top_k": -1, + "top_p": 1, + "use_fire_sampling": False, + "prompt_length": st.session_state["max_prompt_tokens"], + "response_length": st.session_state["max_response_tokens"], + "dtype": "bfloat16", + "gpu_memory_utilization": 0.4, + "ignore_eos": False, + "enforce_eager": True, + "free_cache_engine": True, + "load_format": "dummy_dtensor", + "tensor_model_parallel_size": 2, + "max_num_batched_tokens": 8192, + "max_model_len": None, + "max_num_seqs": 1024, + "log_prob_micro_batch_size_per_gpu": 4, + "log_prob_use_dynamic_bsz": use_dynamic_bsz, + "log_prob_max_token_len_per_gpu": ppo_max_token_len_per_gpu, + "disable_log_stats": True, + "enable_chunked_prefill": True, + "do_sample": True, + "n": st.session_state["repeat_times"], }, - "trainer": { - "balance_batch": balance_batch, - "total_epochs": total_epoch, - 
"project_name": project, - "experiment_name": name, - "logger": ["wandb"], - "val_generations_to_log_to_wandb": 0, - "nnodes": trainer_nnodes, - "n_gpus_per_node": trainer_n_gpus_per_node, - "save_freq": save_freq, - "resume_mode": resume_mode, - "resume_from_path": resume_from_path, - "test_freq": 100, - "critic_warmup": critic_warmup, - "default_hdfs_dir": default_hdfs_dir, - "remove_previous_ckpt_in_save": remove_previous_ckpt_in_save, - "del_local_ckpt_after_load": del_local_ckpt_after_load, - "default_local_dir": checkpoint_path, - "val_before_train": False, - "sync_freq": sync_iteration_interval, - "max_actor_ckpt_to_keep": max_actor_ckpt_to_keep, - "max_critic_ckpt_to_keep": max_critic_ckpt_to_keep, + }, + "critic": { + "strategy": st.session_state["training_strategy"], + "optim": { + "lr": st.session_state["critic_lr"], + "lr_warmup_steps_ratio": st.session_state["critic_lr_warmup_steps_ratio"], + "warmup_style": st.session_state["critic_warmup_style"], + "total_training_steps": -1 + if st.session_state["total_training_steps"] is None + else st.session_state["total_training_steps"], }, - } + "model": { + "path": critic_model_path, + "tokenizer_path": critic_model_path, + "override_config": {}, + "external_lib": None, + "enable_gradient_checkpointing": enable_gradient_checkpointing, + "use_remove_padding": use_remove_padding, + "fsdp_config": copy.deepcopy(fsdp_config), + }, + "ppo_mini_batch_size": st.session_state["task_num_per_batch"], + "ppo_micro_batch_size_per_gpu": st.session_state[ + "critic_ppo_micro_batch_size_per_gpu" + ], + "forward_micro_batch_size_per_gpu": st.session_state[ + "critic_ppo_micro_batch_size_per_gpu" + ], + "use_dynamic_bsz": use_dynamic_bsz, + "ppo_max_token_len_per_gpu": ppo_max_token_len_per_gpu * 2, + "forward_max_token_len_per_gpu": ppo_max_token_len_per_gpu * 2, + "ulysses_sequence_parallel_size": st.session_state[ + "critic_ulysses_sequence_parallel_size" + ], + "ppo_epochs": ppo_epochs, + "shuffle": False, + "grad_clip": st.session_state["critic_grad_clip"], + "cliprange_value": st.session_state["critic_cliprange_value"], + }, + "reward_model": { + "enable": False, + "strategy": "fsdp", + "model": { + "input_tokenizer": st.session_state["model_path"], + "path": "~/models/FsfairX-LLaMA3-RM-v0.1", + "external_lib": None, + "use_remove_padding": False, + "fsdp_config": { + "min_num_params": 0, + "param_offload": False, + "fsdp_size": -1, + }, + }, + "ulysses_sequence_parallel_size": 1, + "use_dynamic_bsz": use_dynamic_bsz, + "forward_max_token_len_per_gpu": ppo_max_token_len_per_gpu * 2, + "reward_manager": "naive", + }, + "custom_reward_function": {"path": None, "name": "compute_score"}, + "algorithm": { + "gamma": st.session_state["gamma"], + "lam": st.session_state["lam"], + "adv_estimator": st.session_state["adv_estimator"], + "kl_penalty": st.session_state["kl_penalty"], + "kl_ctrl": { + "type": st.session_state["kl_ctrl_type"], + "kl_coef": st.session_state["kl_ctrl_coef"], + }, + }, + "trainer": { + "balance_batch": balance_batch, + "total_epochs": st.session_state["total_epochs"], + "project_name": st.session_state["project"], + "experiment_name": st.session_state["exp_name"], + "logger": ["wandb"], + "val_generations_to_log_to_wandb": 0, + "nnodes": trainer_nnodes, + "n_gpus_per_node": trainer_n_gpus_per_node, + "save_freq": st.session_state["save_freq"], + "resume_mode": st.session_state["resume_mode"], + "resume_from_path": st.session_state["resume_from_path"], + "test_freq": 100, + "critic_warmup": st.session_state["critic_warmup"], + 
"default_hdfs_dir": st.session_state["default_hdfs_dir"], + "remove_previous_ckpt_in_save": st.session_state["remove_previous_ckpt_in_save"], + "del_local_ckpt_after_load": st.session_state["del_local_ckpt_after_load"], + "default_local_dir": st.session_state["checkpoint_path"], + "val_before_train": False, + "sync_freq": st.session_state["sync_iteration_interval"], + "max_actor_ckpt_to_keep": st.session_state["max_actor_ckpt_to_keep"], + "max_critic_ckpt_to_keep": st.session_state["max_critic_ckpt_to_keep"], + }, + } + return trainer_config + + def generate_config(self): + trainer_nnodes = ( + st.session_state["node_num"] + - st.session_state["engine_num"] + * st.session_state["tensor_parallel_size"] + // st.session_state["gpu_per_node"] + ) + if st.session_state["node_num"] == 1: + trainer_n_gpus_per_node = ( + st.session_state["gpu_per_node"] + - st.session_state["engine_num"] * st.session_state["tensor_parallel_size"] + ) else: - raise ValueError(f"Invalid trainer type: {trainer_type}") + trainer_n_gpus_per_node = st.session_state["gpu_per_node"] - if st.button("Generate Config", disabled=self.unfinished_flag): + db_url = ( + st.session_state["db_url"] + if st.session_state["db_url"].strip() + else f"sqlite:///{os.path.join(st.session_state['checkpoint_path'], '.cache', st.session_state['project'], st.session_state['exp_name'])}/data.db" + ) + sft_storage_type = ( + StorageType.SQL.value + if "://" in st.session_state["sft_warmup_dataset_path"] + else StorageType.FILE.value + ) # TODO + if st.session_state["trainer_type"] == "verl": + trainer_config = self._generate_verl_config( + trainer_nnodes=trainer_nnodes, trainer_n_gpus_per_node=trainer_n_gpus_per_node + ) + else: + raise ValueError(f"Invalid trainer type: {st.session_state['trainer_type']}") + + if len(self.unfinished_fields) > 0: + disable_generate = True + help_messages = ( + f"Please check following fields: `{'`, `'.join(self.unfinished_fields)}`" + ) + else: + disable_generate = False + help_messages = None + if st.button( + "Generate Config", + disabled=disable_generate, + help=help_messages, + ): config = { "data": { - "total_epochs": total_epoch, - "batch_size": batch_size_per_gpu * gpu_num, - "dataset_path": dataset_path, - "default_workflow_type": default_workflow_type, - "default_reward_fn_type": default_reward_fn_type, - "train_split": train_split, - "eval_split": eval_split, + "total_epochs": st.session_state["total_epochs"], + "batch_size": st.session_state["task_num_per_batch"], + "dataset_path": st.session_state["dataset_path"], + "default_workflow_type": st.session_state["default_workflow_type"], + "default_reward_fn_type": st.session_state["default_reward_fn_type"], + "train_split": st.session_state["train_split"], + "eval_split": st.session_state["eval_split"], "format_config": { - "prompt_key": prompt_key, - "response_key": response_key, + "prompt_key": st.session_state["prompt_key"], + "response_key": st.session_state["response_key"], }, }, "model": { - "model_path": model_path, - "max_prompt_tokens": max_prompt_tokens, - "max_response_tokens": max_response_tokens, - "checkpoint_path": checkpoint_path, + "model_path": st.session_state["model_path"], + "max_prompt_tokens": st.session_state["max_prompt_tokens"], + "max_response_tokens": st.session_state["max_response_tokens"], + "checkpoint_path": st.session_state["checkpoint_path"], }, "cluster": { - "node_num": node_num, - "gpu_per_node": gpu_per_node, + "node_num": st.session_state["node_num"], + "gpu_per_node": st.session_state["gpu_per_node"], }, "buffer": 
{ - "storage_type": storage_type, "db_url": db_url, - "read_batch_size": batch_size_per_gpu * gpu_num * repeat_times, - "max_retry_times": max_retry_times, - "max_retry_interval": max_retry_interval, + "read_batch_size": st.session_state["task_num_per_batch"] + * st.session_state["repeat_times"], + "max_retry_times": st.session_state["max_retry_times"], + "max_retry_interval": st.session_state["max_retry_interval"], + "train_dataset": { + "name": "experience_buffer", # TODO + "storage_type": st.session_state["storage_type"], + "algorithm_type": st.session_state["algorithm_type"], + "path": db_url, + }, + "sft_warmup_dataset": { + "name": "sft_warmup_dataset", + "storage_type": sft_storage_type, + "algorithm_type": AlgorithmType.SFT.value, + "path": st.session_state["sft_warmup_dataset_path"], + }, }, "explorer": { - "engine_type": engine_type, - "engine_num": engine_num, - "runner_num": runner_num, - "tensor_parallel_size": tensor_parallel_size, - "enable_prefix_caching": enable_prefix_caching, - "enforce_eager": enforce_eager, - "dtype": dtype, - "temperature": temperature, - "top_p": top_p, - "top_k": top_k, - "seed": seed, - "logprobs": logprobs, - "repeat_times": repeat_times, - "backend": backend, - "max_pending_requests": max_pending_requests, - "max_waiting_steps": max_waiting_steps, + "engine_type": st.session_state["engine_type"], + "engine_num": st.session_state["engine_num"], + "runner_num": st.session_state["runner_num"], + "tensor_parallel_size": st.session_state["tensor_parallel_size"], + "enable_prefix_caching": st.session_state["enable_prefix_caching"], + "enforce_eager": st.session_state["enforce_eager"], + "dtype": st.session_state["dtype"], + "temperature": st.session_state["temperature"], + "top_p": st.session_state["top_p"], + "top_k": st.session_state["top_k"], + "seed": st.session_state["seed"], + "logprobs": st.session_state["logprobs"], + "repeat_times": st.session_state["repeat_times"], + "backend": st.session_state["backend"], + "max_pending_requests": st.session_state["max_pending_requests"], + "max_waiting_steps": st.session_state["max_waiting_steps"], }, "synchronizer": { - "sync_method": sync_method, - "sync_iteration_interval": sync_iteration_interval, + "sync_method": st.session_state["sync_method"], + "sync_iteration_interval": st.session_state["sync_iteration_interval"], }, "trainer": { - "trainer_type": trainer_type, - "trainer_config_path": trainer_config_path, - "sft_warmup_iteration": sft_warmup_iteration, - "eval_interval": eval_interval, + "trainer_type": st.session_state["trainer_type"], + "algorithm_type": st.session_state["algorithm_type"], + "trainer_config": trainer_config, + "sft_warmup_iteration": st.session_state["sft_warmup_iteration"], + "eval_interval": st.session_state["eval_interval"], }, "monitor": { - "project": project, - "name": name, + "project": st.session_state["project"], + "name": st.session_state["exp_name"], + "monitor_type": st.session_state["monitor_type"], }, } st.header("Generated Config File") - st.subheader("Overall Config File") + st.subheader("Config File") yaml_config = yaml.dump(config, allow_unicode=True, sort_keys=False) st.code(yaml_config, language="yaml") - st.subheader("Trainer Config File") - trainer_config = yaml.dump(trainer_config, allow_unicode=True, sort_keys=False) - st.code(trainer_config, language="yaml") - - def main(self): - mode = st.pills( - "Select Mode", - options=["Beginer Mode", "Expert Mode"], - default="Expert Mode", - label_visibility="collapsed", - ) - if mode == "Beginer Mode": - 
self.beginer_mode() - else: - self.expert_mode() if __name__ == "__main__": config_manager = ConfigManager() - config_manager.main() diff --git a/trinity/trainer/verl/fsdp_workers.py b/trinity/trainer/verl/fsdp_workers.py index 94f5939f53..0273fd0544 100644 --- a/trinity/trainer/verl/fsdp_workers.py +++ b/trinity/trainer/verl/fsdp_workers.py @@ -928,17 +928,6 @@ def _build_critic_model_optimizer(self, config): ) critic_model_config.num_labels = 1 - use_remove_padding = config.model.get("use_remove_padding", False) - if use_remove_padding: - from verl.models.registry import check_model_support_rmpad - - check_model_support_rmpad(critic_model_config.model_type) - - if use_remove_padding and self.ulysses_sequence_parallel_size > 1: - from verl.models.transformers.monkey_patch import apply_monkey_patch - - apply_monkey_patch(critic_model_config, verbose=True) - init_context = get_init_weight_context_manager( use_meta_tensor=not critic_model_config.tie_word_embeddings, mesh=self.device_mesh )