diff --git a/docs/sphinx_doc/source/tutorial/example_data_functionalities.md b/docs/sphinx_doc/source/tutorial/example_data_functionalities.md
index fb9a82873b..d7528fa0df 100644
--- a/docs/sphinx_doc/source/tutorial/example_data_functionalities.md
+++ b/docs/sphinx_doc/source/tutorial/example_data_functionalities.md
@@ -42,7 +42,7 @@ data:
# database related. The result dataset will be stored in the database.
db_url: 'postgresql://{user_name}@localhost:5432/{db_name}'
# downstream loading related
- total_epoch: 1
+ total_epochs: 1
batch_size: 96
default_workflow_type: 'math_workflow'
```
@@ -53,7 +53,7 @@ Here you can set the basic information for the GSM-8K dataset, database informat
+ `dataset_config`: extra config arguments for loading the raw dataset. Mainly for the `load_dataset` method in HuggingFace `datasets` library.
+ `format_config`: some dataset format config items, which are used to map original data field names to unified ones.
+ `db_url`: the URL of the postgresql database to store the result dataset.
-+ `total_epoch`: the total number of epochs to train on this dataset.
++ `total_epochs`: the total number of epochs to train on this dataset.
+ `batch_size`: the training batch size.
+ `default_workflow_type`: the default exploring workflow type. Please refer to [programming guide](trinity_programming_guide.md) for more details.
@@ -74,7 +74,7 @@ data:
# database related. The result dataset will be stored in the database.
db_url: 'postgresql://{user_name}@localhost:5432/{db_name}'
# downstream loading related
- total_epoch: 1
+ total_epochs: 1
batch_size: 96
default_workflow_type: 'math_workflow'
@@ -120,7 +120,7 @@ data:
# database related. The result dataset will be stored in the database.
db_url: 'postgresql://{user_name}@localhost:5432/{db_name}'
# downstream loading related
- total_epoch: 1
+ total_epochs: 1
batch_size: 96
default_workflow_type: 'math_workflow'
@@ -199,7 +199,7 @@ data:
# database related. The result dataset will be stored in the database.
db_url: 'postgresql://{user_name}@localhost:5432/{db_name}'
# downstream loading related
- total_epoch: 20
+ total_epochs: 20
batch_size: 32
default_workflow_type: 'math_workflow'
```
diff --git a/docs/sphinx_doc/source/tutorial/trinity_configs.md b/docs/sphinx_doc/source/tutorial/trinity_configs.md
index 0ef8e93db5..0be1153630 100644
--- a/docs/sphinx_doc/source/tutorial/trinity_configs.md
+++ b/docs/sphinx_doc/source/tutorial/trinity_configs.md
@@ -3,6 +3,18 @@
The following is the main config file for Trinity-RFT. Take `countdown.yaml` as an example.
+## Monitor
+
+```yaml
+monitor:
+ project: "Trinity-RFT-countdown"
+ name: "qwen2.5-1.5B-countdown"
+```
+
+- `monitor.project`: The project name. It must be set manually.
+- `monitor.name`: The name of the experiment. It must be set manually.
+
+
## Monitor
```yaml
@@ -33,7 +45,7 @@ data:
max_retry_times: 3
max_retry_interval: 1
- total_epoch: 20
+ total_epochs: 20
batch_size: 96
default_workflow_type: 'math_workflow'
default_reward_fn_type: 'countdown_reward'
@@ -47,7 +59,7 @@ data:
- `data.db_url`: The URL of the database.
- `data.max_retry_times`: The maximum number of retries when loading the dataset from database.
- `data.max_retry_interval`: The maximum interval between retries when loading the dataset from database.
-- `data.total_epoch`: The total number of epochs to explore the dataset. Default is `1`. It should be set manually.
+- `data.total_epochs`: The total number of epochs to explore the dataset. Default is `1`. It should be set manually.
- `data.batch_size`: The number of `Task`s in one training batch. The real batch size used in training is `data.batch_size` * `actor_rollout_ref.rollout.n`. Default is `1`. It should be set manually.
- `data.default_workflow_type`: The default workflow type used for training.
- `data.default_reward_fn_type`: The default reward function type used for training.
@@ -345,10 +357,14 @@ algorithm:
gamma: 1.0
lam: 1.0
adv_estimator: gae
+ norm_adv_by_std_in_grpo: True
+ use_kl_in_reward: False
kl_penalty: kl # how to estimate kl divergence
kl_ctrl:
type: fixed
kl_coef: 0.001
+ horizon: 10000
+ target_kl: 0.1
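+    # `horizon` and `target_kl` take effect only when `kl_ctrl.type` is `adaptive`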
trainer:
balance_batch: True
@@ -363,7 +379,7 @@ trainer:
save_freq: 100
# auto: find the last ckpt to resume. If can't find, start from scratch
resume_mode: auto # or auto or resume_path if
- resume_from_path: False
+ resume_from_path: ""
test_freq: 100
critic_warmup: 0
default_hdfs_dir: null
@@ -383,8 +399,9 @@ trainer:
- `actor_rollout_ref.actor.grad_clip`: Gradient clip for actor model training.
- `actor_rollout_ref.actor.clip_ratio`: Used for compute policy loss.
- `actor_rollout_ref.actor.entropy_coeff`: Used for compute policy loss.
-- `actor_rollout_ref.actor.use_kl_loss`: True for GRPO.
-- `actor_rollout_ref.actor.kl_loss_coef`: Used for GRPO, optional value is `kl`, `abs`, `mse` or `low_var_kl`.
+- `actor_rollout_ref.actor.use_kl_loss`: Whether to enable the KL loss.
+- `actor_rollout_ref.actor.kl_loss_coef`: The coefficient of the KL loss.
+- `actor_rollout_ref.actor.kl_loss_type`: How to compute the KL loss; optional values are `kl`, `abs`, `mse`, or `low_var_kl` (see the sketch below).
- `actor_rollout_ref.actor.ulysses_sequence_parallel_size`: Ulysses sequence parallel size.
- `actor_rollout_ref.actor.alg_type`: Used for OPMD, optional value is `ppo`, `opmd` or `pairwise_opmd`.
- `actor_rollout_ref.actor.tau`: strength of regularization w.r.t. old / ref policy.
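+
+A minimal sketch of the corresponding YAML (field names follow the options documented above; the values are illustrative, not recommendations):
+
+```yaml
+actor_rollout_ref:
+  actor:
+    use_kl_loss: True
+    kl_loss_coef: 0.001
+    kl_loss_type: low_var_kl
+```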
diff --git a/examples/dpo_humanlike/dpo.yaml b/examples/dpo_humanlike/dpo.yaml
index 991c662138..29c49bb1a0 100644
--- a/examples/dpo_humanlike/dpo.yaml
+++ b/examples/dpo_humanlike/dpo.yaml
@@ -1,6 +1,6 @@
mode: train
data:
- total_epoch: 20
+ total_epochs: 20
batch_size: 32 # NOTE
train_split: "train"
dataset_path: ''
@@ -22,7 +22,6 @@ buffer:
train_dataset:
name: dpo_buffer
storage_type: file
- algorithm_type: dpo
path: '/PATH/TO/DATASET/'
kwargs:
prompt_type: plaintext # plaintext/messages
diff --git a/examples/dpo_humanlike/train_dpo.yaml b/examples/dpo_humanlike/train_dpo.yaml
index ae7689106b..4da9a7ddb5 100644
--- a/examples/dpo_humanlike/train_dpo.yaml
+++ b/examples/dpo_humanlike/train_dpo.yaml
@@ -173,7 +173,6 @@ trainer:
save_freq: 30
# auto: find the last ckpt to resume. If can't find, start from scratch
resume_mode: auto # or auto or resume_path if
- resume_from_path: False
test_freq: 5
critic_warmup: 0
default_hdfs_dir: null
diff --git a/examples/grpo_alfworld/alfworld.yaml b/examples/grpo_alfworld/alfworld.yaml
index 100420acd8..0258bfdbf0 100644
--- a/examples/grpo_alfworld/alfworld.yaml
+++ b/examples/grpo_alfworld/alfworld.yaml
@@ -1,5 +1,5 @@
data:
- total_epoch: 20
+ total_epochs: 20
batch_size: 4
dataset_path: 'scripts/data_prepare/alfworld_data'
default_workflow_type: 'alfworld_workflow'
@@ -21,7 +21,6 @@ buffer:
train_dataset:
name: alfworld_buffer
storage_type: queue
- algorithm_type: ppo
path: 'sqlite:///alfworld.db'
explorer:
engine_type: vllm_async
diff --git a/examples/grpo_alfworld/train_alfworld.yaml b/examples/grpo_alfworld/train_alfworld.yaml
index 0a1c109754..88f151bdcb 100644
--- a/examples/grpo_alfworld/train_alfworld.yaml
+++ b/examples/grpo_alfworld/train_alfworld.yaml
@@ -172,7 +172,6 @@ trainer:
save_freq: 1
# auto: find the last ckpt to resume. If can't find, start from scratch
resume_mode: auto # or auto or resume_path if
- resume_from_path: False
test_freq: 100
critic_warmup: 0
default_hdfs_dir: null
diff --git a/examples/grpo_gsm8k/gsm8k.yaml b/examples/grpo_gsm8k/gsm8k.yaml
index 677027ea19..63850d5d24 100644
--- a/examples/grpo_gsm8k/gsm8k.yaml
+++ b/examples/grpo_gsm8k/gsm8k.yaml
@@ -18,7 +18,7 @@ data:
# db related
db_url: ''
# downstream loading related
- total_epoch: 1
+ total_epochs: 1
batch_size: 96
default_workflow_type: 'math_workflow'
model:
@@ -35,12 +35,10 @@ buffer:
train_dataset:
name: gsm8k_buffer
storage_type: queue
- algorithm_type: ppo
path: 'sqlite:///gsm8k.db'
# sft_warmup_dataset: # Uncomment these to enable sft warmup
# name: warmup_data
# storage_type: file
- # algorithm_type: sft
# path: '/PATH/TO/WARMUP_DATA/'
# kwargs:
# prompt_type: plaintext
diff --git a/examples/grpo_gsm8k/train_gsm8k.yaml b/examples/grpo_gsm8k/train_gsm8k.yaml
index eeb64dc746..2e8365c6cb 100644
--- a/examples/grpo_gsm8k/train_gsm8k.yaml
+++ b/examples/grpo_gsm8k/train_gsm8k.yaml
@@ -177,7 +177,6 @@ trainer:
save_freq: 100
# auto: find the last ckpt to resume. If can't find, start from scratch
resume_mode: auto # or auto or resume_path if
- resume_from_path: False
test_freq: 5
critic_warmup: 0
default_hdfs_dir: null
diff --git a/examples/grpo_math/math.yaml b/examples/grpo_math/math.yaml
index ae7e40d9d0..07ea448548 100644
--- a/examples/grpo_math/math.yaml
+++ b/examples/grpo_math/math.yaml
@@ -10,7 +10,7 @@ data:
# db related
db_url: ''
# downstream loading related
- total_epoch: 20
+ total_epochs: 20
batch_size: 288
default_workflow_type: 'math_workflow'
model:
@@ -27,8 +27,7 @@ buffer:
train_dataset:
name: math_buffer
storage_type: queue
- algorithm_type: ppo
- path: 'sqlite:////math.db'
+ path: 'sqlite:///math.db'
explorer:
engine_type: vllm_async
engine_num: 2
diff --git a/examples/grpo_math/train_math.yaml b/examples/grpo_math/train_math.yaml
index c49fda58ce..937c19657e 100644
--- a/examples/grpo_math/train_math.yaml
+++ b/examples/grpo_math/train_math.yaml
@@ -169,7 +169,6 @@ trainer:
save_freq: 100
# auto: find the last ckpt to resume. If can't find, start from scratch
resume_mode: auto # or auto or resume_path if
- resume_from_path: False
test_freq: 5
critic_warmup: 0
default_hdfs_dir: null
diff --git a/examples/grpo_sciworld/sciworld.yaml b/examples/grpo_sciworld/sciworld.yaml
index d036fc54fc..25b5dfa073 100644
--- a/examples/grpo_sciworld/sciworld.yaml
+++ b/examples/grpo_sciworld/sciworld.yaml
@@ -1,5 +1,5 @@
data:
- total_epoch: 20
+ total_epochs: 20
batch_size: 4
dataset_path: 'scripts/data_prepare/sciworld_data'
default_workflow_type: 'sciworld_workflow'
@@ -21,7 +21,6 @@ buffer:
train_dataset:
name: sciworld_buffer
storage_type: queue
- algorithm_type: ppo
path: 'sqlite:///sciworld.db'
explorer:
engine_type: vllm_async
diff --git a/examples/grpo_sciworld/train_sciworld.yaml b/examples/grpo_sciworld/train_sciworld.yaml
index 880fc61fcc..330b659afb 100644
--- a/examples/grpo_sciworld/train_sciworld.yaml
+++ b/examples/grpo_sciworld/train_sciworld.yaml
@@ -167,7 +167,6 @@ trainer:
save_freq: 1
# auto: find the last ckpt to resume. If can't find, start from scratch
resume_mode: auto # or auto or resume_path if
- resume_from_path: False
test_freq: 100
critic_warmup: 0
default_hdfs_dir: null
diff --git a/examples/grpo_webshop/train_webshop.yaml b/examples/grpo_webshop/train_webshop.yaml
index aac06b7043..0ae8675f50 100644
--- a/examples/grpo_webshop/train_webshop.yaml
+++ b/examples/grpo_webshop/train_webshop.yaml
@@ -172,7 +172,6 @@ trainer:
save_freq: 1
# auto: find the last ckpt to resume. If can't find, start from scratch
resume_mode: auto # or auto or resume_path if
- resume_from_path: False
test_freq: 100
critic_warmup: 0
default_hdfs_dir: null
diff --git a/examples/grpo_webshop/webshop.yaml b/examples/grpo_webshop/webshop.yaml
index ff5c2c57bf..eb9916018d 100644
--- a/examples/grpo_webshop/webshop.yaml
+++ b/examples/grpo_webshop/webshop.yaml
@@ -1,5 +1,5 @@
data:
- total_epoch: 20
+ total_epochs: 20
batch_size: 4
dataset_path: 'scripts/data_prepare/webshop_data'
default_workflow_type: 'webshop_workflow'
@@ -21,7 +21,6 @@ buffer:
train_dataset:
name: webshop_buffer
storage_type: queue
- algorithm_type: ppo
path: 'sqlite:///webshop.db'
explorer:
engine_type: vllm_async
diff --git a/examples/opmd_gsm8k/opmd_gsm8k.yaml b/examples/opmd_gsm8k/opmd_gsm8k.yaml
index 5a458b24e8..6cde601158 100644
--- a/examples/opmd_gsm8k/opmd_gsm8k.yaml
+++ b/examples/opmd_gsm8k/opmd_gsm8k.yaml
@@ -1,5 +1,5 @@
data:
- total_epoch: 1
+ total_epochs: 1
batch_size: 96
dataset_path: '{path to datasets}/gsm8k'
default_workflow_type: 'math_workflow'
@@ -20,7 +20,6 @@ buffer:
train_dataset:
name: gsm8k_buffer
storage_type: queue
- algorithm_type: opmd
path: 'sqlite:///gsm8k_opmd.db'
explorer:
engine_type: vllm_async
diff --git a/examples/opmd_gsm8k/train_opmd_gsm8k.yaml b/examples/opmd_gsm8k/train_opmd_gsm8k.yaml
index a82f4d0739..97384e57c3 100644
--- a/examples/opmd_gsm8k/train_opmd_gsm8k.yaml
+++ b/examples/opmd_gsm8k/train_opmd_gsm8k.yaml
@@ -204,7 +204,6 @@ trainer:
save_freq: 100
# auto: find the last ckpt to resume. If can't find, start from scratch
resume_mode: auto # or auto or resume_path if
- resume_from_path: False
test_freq: 100
critic_warmup: 0
default_hdfs_dir: null
diff --git a/examples/ppo_countdown/countdown.yaml b/examples/ppo_countdown/countdown.yaml
index a531eebda0..9282a7d1a0 100644
--- a/examples/ppo_countdown/countdown.yaml
+++ b/examples/ppo_countdown/countdown.yaml
@@ -1,5 +1,5 @@
data:
- total_epoch: 20
+ total_epochs: 20
batch_size: 96
dataset_path: 'countdown_dataset/oneshot-split'
default_workflow_type: 'math_workflow'
@@ -23,8 +23,7 @@ buffer:
train_dataset:
name: countdown_buffer
storage_type: queue
- algorithm_type: ppo
- path: 'sqlite:////countdown.db'
+ path: 'sqlite:///countdown.db'
explorer:
engine_type: vllm_async
engine_num: 2
diff --git a/examples/ppo_countdown/train_countdown.yaml b/examples/ppo_countdown/train_countdown.yaml
index 53a1401c4b..70872fd7f1 100644
--- a/examples/ppo_countdown/train_countdown.yaml
+++ b/examples/ppo_countdown/train_countdown.yaml
@@ -179,7 +179,6 @@ trainer:
save_freq: 100
# auto: find the last ckpt to resume. If can't find, start from scratch
resume_mode: auto # or auto or resume_path if
- resume_from_path: False
test_freq: 100
critic_warmup: 0
default_hdfs_dir: null
diff --git a/pyproject.toml b/pyproject.toml
index 717a88a00c..8a6f226c77 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -32,6 +32,7 @@ dependencies = [
"math_verify",
"ninja",
"fire",
+ "streamlit",
"flask",
"requests",
"tensorboard",
diff --git a/tests/common/tmp/template_config.yaml b/tests/common/tmp/template_config.yaml
index 2cae62acdf..e163623680 100644
--- a/tests/common/tmp/template_config.yaml
+++ b/tests/common/tmp/template_config.yaml
@@ -1,7 +1,7 @@
mode: both
data:
dataset_path: ''
- total_epoch: 1
+ total_epochs: 1
batch_size: 32
train_split: 'train'
eval_split: ''
diff --git a/trinity/common/config.py b/trinity/common/config.py
index 8fdd9f8e27..587c7c5fe8 100644
--- a/trinity/common/config.py
+++ b/trinity/common/config.py
@@ -69,7 +69,7 @@ class DataConfig:
max_retry_interval: int = 1
# downstream loading related
- total_epoch: int = 1
+ total_epochs: int = 1
batch_size: int = 1
default_workflow_type: str = ""
default_reward_fn_type: str = ""
@@ -101,7 +101,7 @@ class DatasetConfig:
name: str
storage_type: StorageType
- algorithm_type: AlgorithmType
+ algorithm_type: AlgorithmType = AlgorithmType.PPO
path: Optional[str] = None
kwargs: Dict[str, Any] = field(default_factory=dict)
@@ -143,7 +143,7 @@ class ExplorerConfig:
# For async engine (vllm_async), it can be larger than `engine_num`, e.g. 16 * `engine_num`
runner_num: int = 1
- # repeat each task for `repeat_times` times (for GPRO-like algrorithms)
+    # repeat each task `repeat_times` times (for GRPO-like algorithms)
repeat_times: int = 1
    # for rollout tokenization
@@ -177,7 +177,7 @@ class TrainerConfig:
trainer_config_path: str = ""
eval_interval: int = 100
enable_preview: bool = True # enable rollout preview in wandb
- trainer_config: Any = None
+ trainer_config: Any = field(default_factory=dict)
# train algorithm
algorithm_type: AlgorithmType = AlgorithmType.PPO
@@ -266,21 +266,29 @@ def _check_buffer(self) -> None:
else:
if self.buffer.train_dataset is None:
raise ValueError("buffer.train_dataset is required when mode is not 'both'")
- if self.buffer.train_dataset.algorithm_type != self.trainer.algorithm_type:
- raise ValueError(
- f"buffer.train_dataset.algorithm_type ({self.buffer.train_dataset.algorithm_type}) "
- f"is not consistent with trainer.algorithm_type ({self.trainer.algorithm_type})"
- )
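+            # keep the train dataset's algorithm_type aligned with the trainer's algorithm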
+ self.buffer.train_dataset.algorithm_type = self.trainer.algorithm_type
+ if self.buffer.sft_warmup_dataset is not None:
+ self.buffer.sft_warmup_dataset.algorithm_type = AlgorithmType.SFT
self.buffer.read_batch_size = self.data.batch_size * self.explorer.repeat_times
def check_and_update(self) -> None:
"""Check and update the config."""
if self.trainer.trainer_type == "verl":
- from trinity.common.verl_config import load_config
-
- if not os.path.isfile(self.trainer.trainer_config_path):
- raise ValueError(f"Invalid trainer config path: {self.trainer.trainer_config_path}")
- self.trainer.trainer_config = load_config(self.trainer.trainer_config_path)
+ if self.trainer.trainer_config:
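+                # an inline `trainer_config` was supplied; merge it into the structured veRL config schema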
+ from trinity.common.verl_config import veRLConfig
+
+ trainer_config_schema = OmegaConf.structured(veRLConfig)
+ trainer_config = OmegaConf.merge(trainer_config_schema, self.trainer.trainer_config)
+ self.trainer.trainer_config = OmegaConf.to_object(trainer_config)
+ else:
+ if os.path.isfile(self.trainer.trainer_config_path):
+ from trinity.common.verl_config import load_config
+
+ self.trainer.trainer_config = load_config(self.trainer.trainer_config_path)
+ else:
+ raise ValueError(
+ f"Invalid trainer config path: {self.trainer.trainer_config_path}"
+ )
else:
raise ValueError(f"Invalid trainer type: {self.trainer_type}")
diff --git a/trinity/common/task.py b/trinity/common/task.py
index bef638b6c9..5f5309565e 100644
--- a/trinity/common/task.py
+++ b/trinity/common/task.py
@@ -121,7 +121,7 @@ class TaskSet:
task_type: Optional[TaskType] = None
default_index: int = 0
default_epoch: int = 0
- total_epoch: int = 1
+ total_epochs: int = 1
_tasks: Iterator[Task] = None
_index: int = 0
_epoch: int = 0
@@ -160,7 +160,7 @@ def load(
task_type=task_type,
default_index=latest_task_index % dataset_len,
default_epoch=latest_task_index // dataset_len,
- total_epoch=config.total_epoch if task_type == TaskType.EXPLORE else 1,
+ total_epochs=config.total_epochs if task_type == TaskType.EXPLORE else 1,
)
def __iter__(self) -> Iterator[Task]:
@@ -189,7 +189,7 @@ def epoch(self) -> int:
def __next__(self) -> Task:
"""Iterate through the tasks in the taskset."""
- if self._epoch >= self.total_epoch:
+ if self._epoch >= self.total_epochs:
raise StopIteration
try:
@@ -204,7 +204,7 @@ def __next__(self) -> Task:
# Reset the task generator and increment the epoch
self._epoch += 1
self._index += 1
- if self._epoch >= self.total_epoch:
+ if self._epoch >= self.total_epochs:
raise StopIteration
self._tasks = task_generator(
self.dataset,
diff --git a/trinity/common/verl_config.py b/trinity/common/verl_config.py
index a9d7dd6cb4..966fde0391 100644
--- a/trinity/common/verl_config.py
+++ b/trinity/common/verl_config.py
@@ -83,13 +83,13 @@ class Actor:
ppo_epochs: int = 1
shuffle: bool = False
ulysses_sequence_parallel_size: int = 1
+ checkpoint: Checkpoint = field(default_factory=Checkpoint)
optim: Optim = field(default_factory=Optim)
fsdp_config: FSDPConfig = field(default_factory=FSDPConfig)
alg_type: str = "ppo" # ppo / opmd / pairwise_opmd
tau: float = 0.001 # strength of regularization w.r.t. old / ref policy
opmd_baseline: str = "mean" # mean / logavgexp, applicable to opmd
use_uid: bool = False # True / False, applicable to pairwise_opmd
- checkpoint: Checkpoint = field(default_factory=Checkpoint)
@dataclass
@@ -205,6 +205,8 @@ class CustomRewardFunction:
class KL_Ctrl:
type: str = "fixed"
kl_coef: float = 0.001
+ horizon: float = 10000
+ target_kl: float = 0.1
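+    # used by the adaptive KL controller; ignored when `type` is `fixed`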
@dataclass
@@ -212,6 +214,8 @@ class Algorithm:
gamma: float = 1.0
lam: float = 1.0
adv_estimator: str = "gae"
+ norm_adv_by_std_in_grpo: bool = True
+ use_kl_in_reward: bool = False
kl_penalty: str = "kl"
kl_ctrl: KL_Ctrl = field(default_factory=KL_Ctrl)
@@ -229,7 +233,7 @@ class Trainer:
n_gpus_per_node: int = 0
save_freq: int = 0
resume_mode: str = "auto"
- resume_from_path: bool = False
+ resume_from_path: str = ""
test_freq: int = 0
critic_warmup: int = 0
default_hdfs_dir: Optional[str] = None
@@ -300,8 +304,10 @@ def synchronize_config(self, config: Config) -> None:
self.actor_rollout_ref.rollout.temperature = config.explorer.temperature
self.actor_rollout_ref.rollout.n = config.explorer.repeat_times
batch_size_per_gpu = self.buffer.read_batch_size // world_size
- self.actor_rollout_ref.actor.alg_type = config.trainer.algorithm_type.value
- print(f"using algorithm type: {self.actor_rollout_ref.actor.alg_type}")
+ self.actor_rollout_ref.actor.alg_type = (
+ config.trainer.algorithm_type.value
+ ) # TODO: refactor `alg_type`
+ # print(f"using algorithm type: {self.actor_rollout_ref.actor.alg_type}")
if self.actor_rollout_ref.actor.alg_type == "dpo": # for DPO
print("Warning: DPO micro batch size is doubled for computing loss.")
diff --git a/trinity/manager/config_manager.py b/trinity/manager/config_manager.py
index c4436705ab..13ade4b011 100644
--- a/trinity/manager/config_manager.py
+++ b/trinity/manager/config_manager.py
@@ -1,838 +1,1353 @@
+import copy
import os
+from typing import List
import streamlit as st
import yaml
-from verl.trainer.ppo.ray_trainer import AdvantageEstimator
+from trinity.common.constants import AlgorithmType, MonitorType, StorageType
from trinity.common.rewards import REWARD_FUNCTIONS
from trinity.common.workflows.workflow import WORKFLOWS
+from trinity.trainer.verl.ray_trainer import AdvantageEstimator
class ConfigManager:
def __init__(self):
+ self._init_default_config()
+ self.unfinished_fields = set()
st.set_page_config(page_title="Trainer Config Generator", page_icon=":robot:")
st.title("Trainer Config Generator")
- self.reset_config()
- self.unfinished_flag = False
+ if "_init_config_manager" not in st.session_state:
+ self.reset_session_state()
+ self.maintain_session_state()
+ mode = st.pills(
+ "Select Mode",
+ options=["Beginner Mode", "Expert Mode"],
+ default="Beginner Mode",
+ label_visibility="collapsed",
+ )
+ if mode == "Beginner Mode":
+ self.beginner_mode()
+ else:
+ self.expert_mode()
+ self.generate_config()
- def reset_config(self):
- pass
+ def _init_default_config(self):
+ self.default_config = {
+ "_init_config_manager": True,
+ "project": "Trinity-RFT",
+ "exp_name": "qwen2.5-1.5B",
+ "monitor_type": MonitorType.WANDB.value,
+ # Model Configs
+ "model_path": "",
+ "critic_model_path": "",
+ "checkpoint_path": "",
+ "node_num": 1,
+ "gpu_per_node": 8,
+ "total_gpu_num": 8,
+ "trainer_gpu_num": 6,
+ "max_prompt_tokens": 1024,
+ "max_response_tokens": 1024,
+ # Data and Buffer Configs
+ "total_epochs": 20,
+ "task_num_per_batch": 6,
+ "dataset_path": "",
+ "subset_name": None,
+ "train_split": "train",
+ "eval_split": "",
+ "prompt_key": "question",
+ "response_key": "answer",
+ "default_workflow_type": "math_workflow",
+ "default_reward_fn_type": "math_reward",
+ "storage_type": StorageType.QUEUE.value,
+ "db_url": "",
+ "max_retry_times": 3,
+ "max_retry_interval": 1,
+ "sft_warmup_dataset_path": "",
+ "sft_warmup_train_split": "train",
+ "sft_warmup_eval_split": "",
+ "sft_warmup_prompt_key": "question",
+ "sft_warmup_response_key": "answer",
+ # Explorer and Sync Configs
+ "engine_type": "vllm_async",
+ "engine_num": 2,
+ "tensor_parallel_size": 1,
+ "repeat_times": 1,
+ "sync_method": "online",
+ "sync_iteration_interval": 10,
+ "runner_num": 32,
+ "max_pending_requests": 32,
+ "max_waiting_steps": 4,
+ "dtype": "bfloat16",
+ "backend": "nccl",
+ "temperature": 1.0,
+ "top_p": 1.0,
+ "top_k": -1,
+ "seed": 42,
+ "logprobs": 0,
+ "enable_prefix_caching": False,
+ "enforce_eager": True,
+ # Trainer Configs
+ "trainer_type": "verl",
+ "algorithm_type": AlgorithmType.PPO.value,
+ "sft_warmup_iteration": 0,
+ "eval_interval": 1000,
+ # veRL Trainer Configs
+ "training_args": [
+ "balance_batch",
+ "gradient_checkpointing",
+ "remove_padding",
+ "dynamic_bsz",
+ ],
+ "save_freq": 100,
+ "training_strategy": "fsdp",
+ "param_offload": False,
+ "optimizer_offload": False,
+ "resume_mode": "auto",
+ "resume_from_path": "",
+ "critic_warmup": 0,
+ "total_training_steps": None,
+ "default_hdfs_dir": None,
+ "remove_previous_ckpt_in_save": False,
+ "del_local_ckpt_after_load": False,
+ "max_actor_ckpt_to_keep": None,
+ "max_critic_ckpt_to_keep": None,
+ "gamma": 1.0,
+ "lam": 1.0,
+ "adv_estimator": "gae",
+ "norm_adv_by_std_in_grpo": True,
+ "use_kl_in_reward": False,
+ "kl_penalty": "low_var_kl",
+ "kl_ctrl_type": "fixed",
+ "kl_ctrl_coef": 0.001,
+ "horizon": 10000,
+ "target_kl": 0.1,
+ "actor_ppo_micro_batch_size_per_gpu": 4,
+ "ref_log_prob_micro_batch_size_per_gpu": 8,
+ "actor_ulysses_sequence_parallel_size": 1,
+ "actor_lr": 1e-6,
+ "actor_warmup_style": "constant",
+ "actor_lr_warmup_steps_ratio": 0.0,
+ "actor_tau": 0.0,
+ "actor_opmd_baseline": "mean",
+ "actor_use_uid": False,
+ "actor_grad_clip": 1.0,
+ "actor_clip_ratio": 0.2,
+ "actor_entropy_coeff": 0.001,
+ "actor_use_kl_loss": True,
+ "actor_kl_loss_coef": 0.001,
+ "actor_kl_loss_type": "low_var_kl",
+ "actor_checkpoint": ["model", "hf_model", "optimizer", "extra"],
+ "critic_lr": 1e-6,
+ "critic_warmup_style": "constant",
+ "critic_lr_warmup_steps_ratio": 0.0,
+ "critic_grad_clip": 1.0,
+ "critic_cliprange_value": 0.5,
+ "critic_ppo_micro_batch_size_per_gpu": 8,
+ "critic_ulysses_sequence_parallel_size": 1,
+ "training_mode": "PPO",
+ }
- def set_value(self, key, value):
- st.session_state[key] = value
+ def reset_session_state(self):
+ for key, value in self.default_config.items():
+ st.session_state[key] = value
- def beginer_mode(self):
- st.write("Work in progress...")
+ def maintain_session_state(self):
+ for key in self.default_config:
+ st.session_state[key] = st.session_state[key]
- def expert_mode(self): # noqa: C901
- model_tab, buffer_tab, connector_tab, trainer_tab = st.tabs(
- ["Model", "Buffer", "Explorer and Synchronizer", "Trainer"]
+ def _set_project(self):
+ st.text_input("Project", key="project")
+
+ def _set_name(self):
+ st.text_input("Experiment Name", key="exp_name")
+
+ def _set_monitor_type(self):
+ st.selectbox(
+ "Monitor Type",
+ options=[monitor_type.value for monitor_type in MonitorType],
+ key="monitor_type",
)
- with model_tab:
- project_col, name_col = st.columns([1, 3])
- project = project_col.text_input("Project", "Trinity-RFT")
- name = name_col.text_input("Experiment Name", "qwen2.5-1.5B")
-
- model_path = st.text_input("Model Path", "")
- if not model_path.strip():
- self.unfinished_flag = True
- st.warning("Please input model path")
- critic_model_path = st.text_input("Critic Model Path (defaults to `model_path`)", "")
- if not critic_model_path.strip():
- critic_model_path = model_path
-
- checkpoint_path = st.text_input("Checkpoint Path", "")
- if not checkpoint_path.strip():
- self.unfinished_flag = True
- st.warning("Please input checkpoint path")
+ def _set_model_path(self):
+ st.text_input("Model Path", key="model_path")
+ if not st.session_state["model_path"].strip():
+ self.unfinished_fields.add("model_path")
+ st.warning("Please input model path.")
+
+ def _set_critic_model_path(self):
+ st.text_input(
+ "Critic Model Path (defaults to `model_path`)",
+ key="critic_model_path",
+ )
+
+ def _set_checkpoint_path(self):
+ st.text_input("Checkpoint Path", key="checkpoint_path")
+ if not st.session_state["checkpoint_path"].strip(): # TODO: may auto generate
+ self.unfinished_fields.add("checkpoint_path")
+ st.warning("Please input checkpoint path.")
+ elif not os.path.isabs(st.session_state["checkpoint_path"].strip()):
+ self.unfinished_fields.add("checkpoint_path")
+ st.warning("Please input an absolute path.")
+
+ def _set_node_num(self):
+ st.number_input("Node Num", key="node_num", min_value=1, on_change=self._set_total_gpu_num)
+
+ def _set_gpu_per_node(self):
+ st.number_input(
+ "GPU Per Node",
+ key="gpu_per_node",
+ min_value=1,
+ max_value=8,
+ on_change=self._set_total_gpu_num,
+ )
+
+ def _set_total_gpu_num(self):
+ st.session_state["total_gpu_num"] = (
+ st.session_state["gpu_per_node"] * st.session_state["node_num"]
+ )
+ self._set_trainer_gpu_num()
+
+ def _set_trainer_gpu_num(self):
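+        # GPUs left for the trainer = total GPUs minus those reserved for the rollout engines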
+ st.session_state["trainer_gpu_num"] = (
+ st.session_state["total_gpu_num"]
+ - st.session_state["engine_num"] * st.session_state["tensor_parallel_size"]
+ )
+
+ def _set_max_prompt_tokens(self):
+ st.number_input("Max Prompt Tokens", key="max_prompt_tokens", min_value=1)
+
+ def _set_max_response_tokens(self):
+ st.number_input("Max Response Tokens", key="max_response_tokens", min_value=1)
+
+ def _set_total_epochs(self):
+ st.number_input("Total Epochs", key="total_epochs", min_value=1)
+
+ @property
+ def _str_for_task_num_per_batch(self):
+ return (
+ f"Please ensure that `task_num_per_batch` can be divided by "
+ f"`gpu_per_node * node_num - engine_num * tensor_parallel_size` "
+ f"= {st.session_state['trainer_gpu_num']}"
+ )
+
+ def _set_task_num_per_batch(self):
+ trainer_gpu_num = st.session_state["trainer_gpu_num"]
+ if st.session_state["task_num_per_batch"] < trainer_gpu_num:
+ st.session_state["task_num_per_batch"] = trainer_gpu_num
+ st.number_input(
+ "Task Num Per Batch",
+ key="task_num_per_batch",
+ min_value=trainer_gpu_num,
+ step=trainer_gpu_num,
+ help=self._str_for_task_num_per_batch,
+ )
+
+ def _check_task_num_per_batch(self):
+ if st.session_state["task_num_per_batch"] % st.session_state["trainer_gpu_num"] != 0:
+ self.unfinished_fields.add("task_num_per_batch")
+ st.warning(self._str_for_task_num_per_batch)
+
+ def _set_dataset_path(self):
+ st.text_input("Dataset Path", key="dataset_path")
+ if not st.session_state["dataset_path"].strip():
+ self.unfinished_fields.add("dataset_path")
+ st.warning("Please input dataset path.")
+
+ def _set_dataset_args(self):
+ if st.session_state["dataset_path"] and "://" not in st.session_state["dataset_path"]:
+ subset_name_col, train_split_col, eval_split_col = st.columns(3)
+ subset_name_col.text_input("Subset Name", key="subset_name")
+ train_split_col.text_input("Train Split", key="train_split")
+ eval_split_col.text_input("Eval Split", key="eval_split")
+ prompt_key_col, response_key_col = st.columns(2)
+ prompt_key_col.text_input("Prompt Key", key="prompt_key")
+ response_key_col.text_input("Response Key", key="response_key")
+
+ def _set_default_workflow_type(self):
+ st.selectbox(
+ "Default Workflow Type",
+ WORKFLOWS.modules.keys(),
+ key="default_workflow_type",
+ help=r"""`simple_workflow`: call 'model.chat()' to get responses.
+
+`math_workflow`: call 'model.chat()' with a pre-defined system prompt to get responses.
+
+Other workflows: conduct multi-turn tasks for the given dataset.
+""",
+ )
+
+ def _set_default_reward_fn_type(self):
+ st.selectbox(
+ "Default Reward Fn Type",
+ REWARD_FUNCTIONS.modules.keys(),
+ key="default_reward_fn_type",
+ help=r"""`accuracy_reward`: check the accuracy for math problems.
+
+`format_reward`: check if the response matches the format (default: `** *`).
+
+`math_reward`: `accuracy_reward` (1 or 0) + `format_reward` (+0.1 or -0.1).
+""",
+ )
+
+ def _set_storage_type(self):
+ st.selectbox(
+ "Storage Type",
+ [storage_type.value for storage_type in StorageType],
+ key="storage_type",
+ )
+
+ def _set_db_url(self):
+ st.text_input(
+ "DB URL",
+ key="db_url",
+            help=r"Defaults to `sqlite:///{os.path.join(checkpoint_path, '.cache', project_name, experiment_name)}/data.db`",
+ )
+
+ def _set_max_retry_times(self):
+ st.number_input("Max Retry Times", key="max_retry_times", min_value=1)
+
+ def _set_max_retry_interval(self):
+ st.number_input("Max Retry Interval", key="max_retry_interval", min_value=1)
+
+ def _check_sft_warmup_dataset_path(self):
+ if st.session_state["sft_warmup_iteration"]:
+ if not st.session_state["sft_warmup_dataset_path"].strip():
+ self.unfinished_fields.add("sft_warmup_dataset_path")
+ st.warning(
+ "Please input SFT warmup dataset path when `sft_warmup_iteration` is not 0"
+ )
+
+ def _set_sft_warmup_dataset_path(self):
+ st.text_input("SFT Warmup Dataset Path", key="sft_warmup_dataset_path")
+ self._check_sft_warmup_dataset_path()
+
+ def _set_sft_warmup_dataset_args(self):
+ if (
+ st.session_state["sft_warmup_dataset_path"]
+ and "://" not in st.session_state["sft_warmup_dataset_path"]
+ ): # TODO
(
- node_num_col,
- gpu_per_node_col,
- max_prompt_tokens_col,
- max_response_tokens_col,
+ sft_warmup_train_split_col,
+ sft_warmup_eval_split_col,
+ sft_warmup_prompt_key_col,
+ sft_warmup_response_key_col,
) = st.columns(4)
- node_num = node_num_col.number_input("Node Num", value=1, min_value=1)
- gpu_per_node = gpu_per_node_col.number_input(
- "GPU Per Node", value=8, min_value=1, max_value=8
- )
- max_prompt_tokens = max_prompt_tokens_col.number_input(
- "Max Prompt Tokens", value=256, min_value=1
- )
- max_response_tokens = max_response_tokens_col.number_input(
- "Max Response Tokens", value=1024, min_value=1
+ sft_warmup_train_split_col.text_input("SFT Train Split", key="sft_warmup_train_split")
+ sft_warmup_eval_split_col.text_input("SFT Eval Split", key="sft_warmup_eval_split")
+ sft_warmup_prompt_key_col.text_input("SFT Prompt Key", key="sft_warmup_prompt_key")
+ sft_warmup_response_key_col.text_input(
+ "SFT Response Key", key="sft_warmup_response_key"
)
- with buffer_tab:
- total_epoch_col, batch_size_per_gpu_col = st.columns(2)
- total_epoch = total_epoch_col.number_input("Total Epoch", value=20, min_value=1)
- batch_size_per_gpu = batch_size_per_gpu_col.number_input(
- "Batch Size Per GPU", value=1, min_value=1
- )
+ def _set_engine_type(self):
+ st.selectbox("Explorer Engine Type", ["vllm_async", "vllm"], key="engine_type")
- dataset_path = st.text_input("Dataset Path", "")
- if not dataset_path.strip():
- self.unfinished_flag = True
- st.warning("Please input dataset path")
-
- if dataset_path and "://" not in dataset_path:
- train_split_col, eval_split_col, prompt_key_col, response_key_col = st.columns(4)
- train_split = train_split_col.text_input("Train Split", "train")
- eval_split = eval_split_col.text_input("Eval Split", "")
- prompt_key = prompt_key_col.text_input("Prompt Key", "question")
- response_key = response_key_col.text_input("Response Key", "answer")
-
- default_workflow_type_col, default_reward_fn_type_col, storage_type_col = st.columns(3)
- default_workflow_type = default_workflow_type_col.selectbox(
- "Default Workflow Type", WORKFLOWS.modules.keys(), index=1
- )
- default_reward_fn_type = default_reward_fn_type_col.selectbox(
- "Default Reward Fn Type", REWARD_FUNCTIONS.modules.keys(), index=3
- )
- storage_type = storage_type_col.selectbox(
- "Storage Type", ["sql", "redis", "queue"], index=2
- )
+ @property
+ def _str_for_engine_num_and_tp_size(self):
+ return r"""and it must meet the following constraints:
+```python
+assert engine_num * tensor_parallel_size < gpu_per_node * node_num
+if node_num > 1:
+ assert gpu_per_node % tensor_parallel_size == 0
+ assert engine_num * tensor_parallel_size % gpu_per_node == 0
+```"""
- buffer_advanced_tab = st.expander("Advanced Config")
- with buffer_advanced_tab:
- db_url = st.text_input(
- "DB URL",
- "",
- help=r"Default to `sqlite:///{os.path.join(checkpoint_path, '.cache', project, name)}/data.db`",
- )
- if not db_url.strip():
- db_url = rf"sqlite:///{os.path.join(checkpoint_path, '.cache', project, name)}/data.db"
+ def _set_engine_num(self):
+ total_gpu_num = st.session_state["gpu_per_node"] * st.session_state["node_num"]
+ max_engine_num = (total_gpu_num - 1) // st.session_state["tensor_parallel_size"]
+ if st.session_state["engine_num"] > max_engine_num:
+ st.session_state["engine_num"] = max_engine_num
+ self._set_trainer_gpu_num()
+ st.number_input(
+ "Engine Num",
+ key="engine_num",
+ min_value=1,
+ max_value=max_engine_num,
+ help=f"`engine_num` is used to set the quantity of inference engines, "
+ f"{self._str_for_engine_num_and_tp_size}",
+ on_change=self._set_trainer_gpu_num,
+ )
- max_retry_times_col, max_retry_interval_col = st.columns(2)
- max_retry_times = max_retry_times_col.number_input(
- "Max Retry Times", value=3, min_value=1
+ def _set_tensor_parallel_size(self):
+ total_gpu_num = st.session_state["gpu_per_node"] * st.session_state["node_num"]
+ max_tensor_parallel_size = (total_gpu_num - 1) // st.session_state["engine_num"]
+ if st.session_state["tensor_parallel_size"] > max_tensor_parallel_size:
+ st.session_state["tensor_parallel_size"] = max_tensor_parallel_size
+ self._set_trainer_gpu_num()
+ st.number_input(
+ "Tensor Parallel Size",
+ key="tensor_parallel_size",
+ min_value=1,
+ max_value=max_tensor_parallel_size,
+ help=f"`tensor_parallel_size` is used to set the tensor parallel size of inference engines, "
+ f"{self._str_for_engine_num_and_tp_size}",
+ on_change=self._set_trainer_gpu_num,
+ )
+
+ def _check_engine_num_and_tp_size(self):
+ node_num = st.session_state["node_num"]
+ gpu_per_node = st.session_state["gpu_per_node"]
+ engine_num = st.session_state["engine_num"]
+ tensor_parallel_size = st.session_state["tensor_parallel_size"]
+ if node_num > 1:
+ if gpu_per_node % tensor_parallel_size != 0:
+ self.unfinished_fields.add("tensor_parallel_size")
+ st.warning(
+ "Please ensure that `tensor_parallel_size` is a factor of `gpu_per_node` when `node_num > 1`."
)
- max_retry_interval = max_retry_interval_col.number_input(
- "Max Retry Interval", value=1, min_value=1
+ if engine_num * tensor_parallel_size % gpu_per_node != 0:
+ self.unfinished_fields.add("engine_num")
+ st.warning(
+ "Please ensure that `engine_num * tensor_parallel_size` can be divided by `gpu_per_node` when `node_num > 1`."
)
- sft_warmup_dataset_path = st.text_input("SFT Warmup Dataset Path", "")
- if sft_warmup_dataset_path and "://" not in sft_warmup_dataset_path: # TODO
- (
- sft_warmup_train_split_col,
- sft_warmup_eval_split_col,
- sft_warmup_prompt_key_col,
- sft_warmup_response_key_col,
- ) = st.columns(4)
- sft_warmup_train_split = sft_warmup_train_split_col.text_input( # noqa: F841
- "SFT Train Split", "train"
- )
- sft_warmup_eval_split = sft_warmup_eval_split_col.text_input( # noqa: F841
- "SFT Eval Split", ""
- )
- sft_warmup_prompt_key = sft_warmup_prompt_key_col.text_input( # noqa: F841
- "SFT Prompt Key", "question"
- )
- sft_warmup_response_key = sft_warmup_response_key_col.text_input( # noqa: F841
- "SFT Response Key", "answer"
- )
- else:
- sft_warmup_train_split = "" # noqa: F841
- sft_warmup_eval_split = "" # noqa: F841
- sft_warmup_prompt_key = "" # noqa: F841
- sft_warmup_response_key = "" # noqa: F841
+ def _set_repeat_times(self):
+ if st.session_state["algorithm_type"] == AlgorithmType.OPMD.value or st.session_state[
+ "adv_estimator"
+ ] in [
+ AdvantageEstimator.GRPO.value,
+ AdvantageEstimator.RLOO.value,
+ ]:
+ min_repeat_times = 2
+ else:
+ min_repeat_times = 1
+ if st.session_state["repeat_times"] < min_repeat_times:
+ st.session_state["repeat_times"] = min_repeat_times
+ st.number_input(
+ "Repeat Times",
+ key="repeat_times",
+ min_value=min_repeat_times,
+ help="`repeat_times` is used to set how many experiences each task can generate, "
+            "and it must be greater than `1` when using `opmd`, `grpo`, or `rloo`.",
+ )
- with connector_tab:
- (
- engine_type_col,
- engine_num_col,
- tensor_parallel_size_col,
- repeat_times_col,
- ) = st.columns(4)
- engine_type = engine_type_col.selectbox(
- "Explorer Engine Type", ["vllm_async", "vllm"], index=0
- )
- if "engine_num" not in st.session_state:
- st.session_state.engine_num = 2
- old_engine_num = min(st.session_state.engine_num, gpu_per_node * node_num)
- engine_num = engine_num_col.number_input(
- "Engine Num",
- value=old_engine_num,
- min_value=1,
- max_value=gpu_per_node * node_num,
- help="cannot exceed `gpu_per_node` * `node_num`",
- )
- st.session_state.engine_num = engine_num
- tensor_parallel_size = tensor_parallel_size_col.number_input(
- "Tensor Parallel Size", value=1, min_value=1, max_value=8
+ def _set_sync_method(self):
+ st.selectbox(
+ "Sync Method",
+ ["online", "offline"],
+ key="sync_method",
+ help="""`online`: the explorer and trainer sync model weights once every `sync_iteration_interval` steps.
+
+`offline`: the trainer saves the model checkpoint, and the explorer loads it every `sync_iteration_interval` steps.""",
+ )
+
+ def _set_sync_iteration_interval(self):
+ st.number_input(
+ "Sync Iteration Interval",
+ key="sync_iteration_interval",
+ min_value=1,
+            help="""The iteration interval at which the `explorer` and `trainer` synchronize model weights.""",
+ )
+
+ def _set_runner_num(self):
+ st.number_input("Runner Num", key="runner_num", min_value=1)
+
+ def _set_max_pending_requests(self):
+ st.number_input("Max Pending Requests", key="max_pending_requests", min_value=1)
+
+ def _set_max_waiting_steps(self):
+ st.number_input("Max Waiting Steps", key="max_waiting_steps", min_value=1)
+
+ def _set_dtype(self):
+ st.selectbox("Dtype", ["float16", "bfloat16", "float32"], key="dtype")
+
+ def _set_backend(self):
+ st.selectbox("Backend", ["nccl"], key="backend")
+
+ def _set_temperature(self):
+ st.number_input("Temperature", key="temperature", min_value=0.0, max_value=2.0)
+
+ def _set_top_p(self):
+ st.number_input("Top-p", key="top_p", min_value=0.0, max_value=1.0)
+
+ def _set_top_k(self):
+ st.number_input(
+ "Top-k",
+ key="top_k",
+ min_value=-1,
+ max_value=512,
+ help="Integer that controls the number of top tokens to consider. Set to -1 to consider all tokens.",
+ )
+
+ def _set_seed(self):
+ st.number_input("Seed", key="seed", step=1)
+
+ def _set_logprobs(self):
+ st.number_input("Logprobs", key="logprobs", min_value=0, max_value=20)
+
+ def _set_enable_prefix_caching(self):
+ st.checkbox("Enable Prefix Caching", key="enable_prefix_caching")
+
+ def _set_enforce_eager(self):
+ st.checkbox("Enforce Eager", key="enforce_eager")
+
+ def _set_trainer_type(self):
+ st.selectbox("Trainer Type", ["verl"], key="trainer_type")
+
+ def _set_algorithm_type(self):
+ st.selectbox(
+ "Algorithm Type",
+ [AlgorithmType.PPO.value, AlgorithmType.DPO.value, AlgorithmType.OPMD.value],
+ key="algorithm_type",
+ )
+
+ def _set_sft_warmup_iteration(self):
+ st.number_input("SFT Warmup Iteration", key="sft_warmup_iteration", min_value=0)
+
+ def _set_eval_interval(self):
+ st.number_input("Eval Interval", key="eval_interval", min_value=1)
+
+ def _set_training_args(self):
+ st.multiselect(
+ "Training Args",
+ [
+ "balance_batch",
+ "gradient_checkpointing",
+ "remove_padding",
+ "dynamic_bsz",
+ ],
+ key="training_args",
+ )
+
+ def _set_save_freq(self):
+ if st.session_state["sync_method"] == "online":
+ freeze_save_freq = False
+ else:
+ st.session_state["save_freq"] = st.session_state["sync_iteration_interval"]
+ freeze_save_freq = True
+ st.number_input(
+ "Save Freq",
+ key="save_freq",
+ min_value=1,
+ help="Set to `sync_iteration_interval` when `sync_method` is `offline`",
+ disabled=freeze_save_freq,
+ )
+
+ def _set_training_strategy(self):
+ st.selectbox(
+ "Training Strategy",
+ ["fsdp", "megatron"],
+ key="training_strategy",
+ help="megatron is not tested",
+ )
+
+ def _set_param_offload(self):
+ st.checkbox("FSDP Param Offload", key="param_offload")
+
+ def _set_optimizer_offload(self):
+ st.checkbox("FSDP Optimizer Offload", key="optimizer_offload")
+
+ def _set_resume_mode(self):
+ st.selectbox("Resume Mode", ["disable", "auto", "resume_path"], key="resume_mode")
+
+ def _set_resume_from_path(self):
+ if st.session_state["resume_mode"] == "resume_path":
+ st.text_input("Resume Path", key="resume_from_path")
+ if (
+ not st.session_state["resume_from_path"].strip()
+ or "global_step_" not in st.session_state["resume_from_path"]
+ ):
+ self.unfinished_fields.add("resume_from_path")
+ st.warning("Please input a valid resume path when `resume_mode == resume_path`")
+
+ def _set_critic_warmup(self):
+ st.number_input("Critic Warmup Iteration", key="critic_warmup", min_value=0)
+
+ def _set_total_training_steps(self):
+ st.number_input("Total Training Steps", key="total_training_steps", min_value=1)
+
+ def _set_default_hdfs_dir(self):
+ st.text_input("Default HDFS Dir", key="default_hdfs_dir")
+
+ def _set_remove_previous_ckpt_in_save(self):
+ st.checkbox("Remove Previous Checkpoint in Save", key="remove_previous_ckpt_in_save")
+
+ def _set_del_local_ckpt_after_load(self):
+ st.checkbox("Delete Local Checkpoint After Load", key="del_local_ckpt_after_load")
+
+ def _set_max_actor_ckpt_to_keep(self):
+ st.number_input("Max Actor Checkpoint to Keep", key="max_actor_ckpt_to_keep", min_value=1)
+
+ def _set_max_critic_ckpt_to_keep(self):
+ st.number_input("Max Critic Checkpoint to Keep", key="max_critic_ckpt_to_keep", min_value=1)
+
+ def _set_gamma(self):
+ st.number_input("Gamma", key="gamma")
+
+ def _set_lam(self):
+ st.number_input("Lambda", key="lam")
+
+ def _set_adv_estimator(self):
+ st.selectbox(
+ "Advantage Estimator",
+ [member.value for member in AdvantageEstimator],
+ key="adv_estimator",
+ )
+
+ def _set_norm_adv_by_std_in_grpo(self):
+ st.checkbox("Norm Adv by Std in GRPO", key="norm_adv_by_std_in_grpo")
+
+ def _set_use_kl_in_reward(self):
+ st.checkbox("Use KL in Reward", key="use_kl_in_reward")
+
+ def _set_kl_penalty(self):
+ st.selectbox("KL Penalty", ["kl", "abs", "mse", "low_var_kl"], key="kl_penalty")
+
+ def _set_kl_ctrl_type(self):
+ st.selectbox("KL Ctrl Type", ["fixed", "adaptive"], key="kl_ctrl_type")
+
+ def _set_kl_ctrl_coef(self):
+ st.number_input("KL Ctrl Coef", key="kl_ctrl_coef", format="%.1e")
+
+ def _set_horizon(self):
+ st.number_input("Horizon", key="horizon", min_value=1.0)
+
+ def _set_target_kl(self):
+ st.number_input("Target KL", key="target_kl", format="%.1e")
+
+ def _set_actor_ppo_micro_batch_size_per_gpu(self):
+ st.number_input(
+ "Micro Batch Size Per GPU for Actor",
+ key="actor_ppo_micro_batch_size_per_gpu",
+ min_value=1,
+ )
+
+ def _set_ref_log_prob_micro_batch_size_per_gpu(self):
+ st.number_input(
+ "Micro Batch Size Per GPU for Ref",
+ key="ref_log_prob_micro_batch_size_per_gpu",
+ min_value=1,
+ )
+
+ def _set_actor_ulysses_sequence_parallel_size(self):
+ st.number_input(
+ "Ulysses Sequence Parallel Size",
+ key="actor_ulysses_sequence_parallel_size",
+ min_value=1,
+ max_value=8,
+ )
+
+ def _set_actor_lr(self):
+ st.number_input(
+ "Learning Rate for Actor",
+ key="actor_lr",
+ min_value=1e-7,
+ max_value=1e-3,
+ format="%.1e",
+ )
+
+ def _set_actor_warmup_style(self):
+ st.selectbox(
+ "LR Warmup Style for Actor",
+ ["constant", "cosine"],
+ key="actor_warmup_style",
+ )
+
+ def _set_actor_lr_warmup_steps_ratio(self):
+ st.number_input(
+ "LR Warmup Steps Ratio for Actor",
+ key="actor_lr_warmup_steps_ratio",
+ min_value=0.0,
+ max_value=1.0,
+ )
+
+ def _set_actor_grad_clip(self):
+ st.number_input("Grad Clip", key="actor_grad_clip", min_value=0.0, max_value=1.0)
+
+ def _set_actor_clip_ratio(self):
+ st.number_input("Clip Ratio", key="actor_clip_ratio", min_value=0.0, max_value=1.0)
+
+ def _set_actor_entropy_coeff(self):
+ st.number_input(
+ "Entropy Coeff",
+ key="actor_entropy_coeff",
+ min_value=0.0,
+ max_value=1.0,
+ format="%.1e",
+ )
+
+ def _set_actor_use_kl_loss(self):
+ st.checkbox("Use KL Loss", key="actor_use_kl_loss")
+
+ def _set_actor_kl_loss_coef(self):
+ st.number_input(
+ "KL Loss Coef",
+ key="actor_kl_loss_coef",
+ min_value=0.0,
+ max_value=1.0,
+ format="%.1e",
+ )
+
+ def _set_actor_kl_loss_type(self):
+ st.selectbox(
+ "KL Loss Type",
+ ["kl", "abs", "mse", "low_var_kl"],
+ key="actor_kl_loss_type",
+ )
+
+ def _set_actor_tau(self):
+ st.number_input(
+ "Tau for OPMD",
+ key="actor_tau",
+ min_value=0.0,
+ format="%.1e",
+ )
+
+ def _set_actor_opmd_baseline(self):
+ st.selectbox(
+ "OPMD Baseline",
+ ["mean", "logavgexp"],
+ key="actor_opmd_baseline",
+ )
+
+ def _set_actor_use_uid(self):
+ st.checkbox("Use UID for OPMD", key="actor_use_uid")
+
+ def _set_actor_checkpoint(self):
+ st.multiselect(
+ "Checkpoint",
+ ["model", "hf_model", "optimizer", "extra"],
+ key="actor_checkpoint",
+ )
+
+ def _set_critic_ppo_micro_batch_size_per_gpu(self):
+ st.number_input(
+ "Micro Batch Size Per GPU for Critic",
+ key="critic_ppo_micro_batch_size_per_gpu",
+ min_value=1,
+ )
+
+ def _set_critic_ulysses_sequence_parallel_size(self):
+ st.number_input(
+ "Ulysses Sequence Parallel Size",
+ key="critic_ulysses_sequence_parallel_size",
+ min_value=1,
+ max_value=8,
+ )
+
+ def _set_critic_lr(self):
+ st.number_input(
+ "Learning Rate for Critic",
+ key="critic_lr",
+ min_value=1e-7,
+ max_value=1e-3,
+ format="%.1e",
+ )
+
+ def _set_critic_warmup_style(self):
+ st.selectbox(
+ "LR Warmup Style for Critic",
+ ["constant", "cosine"],
+ key="critic_warmup_style",
+ )
+
+ def _set_critic_lr_warmup_steps_ratio(self):
+ st.number_input(
+ "LR Warmup Steps Ratio for Critic",
+ key="critic_lr_warmup_steps_ratio",
+ min_value=0.0,
+ max_value=1.0,
+ )
+
+ def _set_critic_grad_clip(self):
+ st.number_input(
+ "Grad Clip for Critic",
+ key="critic_grad_clip",
+ min_value=0.0,
+ max_value=1.0,
+ )
+
+ def _set_critic_cliprange_value(self):
+ st.number_input(
+ "Cliprange Value",
+ key="critic_cliprange_value",
+ min_value=0.0,
+ max_value=1.0,
+ )
+
+ def _set_training_mode(self):
+ st.selectbox("Training Mode", ["PPO", "GRPO", "DPO", "OPMD"], key="training_mode")
+
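+        # map the user-facing training mode onto (algorithm_type, adv_estimator)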
+ if st.session_state["training_mode"] == "PPO":
+ st.session_state["algorithm_type"] = AlgorithmType.PPO.value
+ st.session_state["adv_estimator"] = "gae"
+ elif st.session_state["training_mode"] == "GRPO":
+ st.session_state["algorithm_type"] = AlgorithmType.PPO.value
+ st.session_state["adv_estimator"] = "grpo"
+ elif st.session_state["training_mode"] == "DPO":
+ st.session_state["algorithm_type"] = AlgorithmType.DPO.value
+ st.session_state["adv_estimator"] = "grpo"
+ elif st.session_state["training_mode"] == "OPMD":
+ st.session_state["algorithm_type"] = AlgorithmType.OPMD.value
+ st.session_state["adv_estimator"] = "grpo"
+
+ def _set_configs_with_st_columns(
+ self, config_names: List[str], columns_config: List[int] = None
+ ):
+ if columns_config is None:
+ columns_config = len(config_names)
+ columns = st.columns(columns_config)
+ for col, config_name in zip(columns, config_names):
+ with col:
+ getattr(self, f"_set_{config_name}")()
+
+ def beginner_mode(self):
+ st.header("Essential Configs")
+ self._set_configs_with_st_columns(["project", "name"], columns_config=[1, 3])
+
+ self._set_model_path()
+
+ self._set_checkpoint_path()
+
+ self._set_dataset_path()
+
+ self._set_configs_with_st_columns(["training_mode", "sft_warmup_iteration", "monitor_type"])
+ if st.session_state["sft_warmup_iteration"] > 0:
+ self._set_sft_warmup_dataset_path()
+
+ st.header("Important Configs")
+ self._set_configs_with_st_columns(
+ ["node_num", "gpu_per_node", "engine_num", "tensor_parallel_size"]
+ )
+ self._check_engine_num_and_tp_size()
+
+ self._set_configs_with_st_columns(
+ ["total_epochs", "task_num_per_batch", "max_prompt_tokens", "max_response_tokens"]
+ )
+ self._check_task_num_per_batch()
+
+ self._set_dataset_args()
+
+ if st.session_state["sft_warmup_iteration"] > 0:
+ self._set_sft_warmup_dataset_args()
+
+ self._set_configs_with_st_columns(
+ ["default_workflow_type", "default_reward_fn_type", "repeat_times"]
+ )
+
+ self._set_configs_with_st_columns(["sync_iteration_interval", "eval_interval", "save_freq"])
+
+ self._set_actor_use_kl_loss()
+ if st.session_state["actor_use_kl_loss"]:
+ self._set_configs_with_st_columns(["actor_kl_loss_coef", "actor_kl_loss_type"])
+
+ self._set_configs_with_st_columns(
+ [
+ "actor_ppo_micro_batch_size_per_gpu",
+ "actor_lr",
+ "ref_log_prob_micro_batch_size_per_gpu",
+ ]
+ )
+
+ use_critic = st.session_state["adv_estimator"] == "gae" # TODO: may apply to expert mode
+ if use_critic:
+ self._set_configs_with_st_columns(["critic_ppo_micro_batch_size_per_gpu", "critic_lr"])
+
+ def _expert_model_part(self):
+ self._set_configs_with_st_columns(["project", "name"], columns_config=[1, 3])
+
+ self._set_model_path()
+ self._set_critic_model_path()
+
+ self._set_checkpoint_path()
+
+ self._set_configs_with_st_columns(["monitor_type", "node_num", "gpu_per_node"])
+ self._set_configs_with_st_columns(["max_prompt_tokens", "max_response_tokens"])
+
+ def _expert_buffer_part(self):
+ self._set_configs_with_st_columns(["total_epochs", "task_num_per_batch"])
+ self._check_task_num_per_batch()
+
+ self._set_dataset_path()
+
+ self._set_dataset_args()
+
+ self._set_configs_with_st_columns(
+ ["default_workflow_type", "default_reward_fn_type", "storage_type"]
+ )
+
+ self.buffer_advanced_tab = st.expander("Advanced Config")
+ with self.buffer_advanced_tab:
+ self._set_db_url()
+
+ self._set_configs_with_st_columns(["max_retry_times", "max_retry_interval"])
+
+ self._set_sft_warmup_dataset_path()
+ self._set_sft_warmup_dataset_args()
+
+ def _expert_connector_part(self):
+ self._set_configs_with_st_columns(
+ ["engine_type", "engine_num", "tensor_parallel_size", "repeat_times"]
+ )
+ self._check_engine_num_and_tp_size()
+
+ self._set_configs_with_st_columns(["sync_method", "sync_iteration_interval"])
+
+ with st.expander("Advanced Config"):
+ self._set_configs_with_st_columns(
+ ["runner_num", "max_pending_requests", "max_waiting_steps", "dtype"]
)
- repeat_times = repeat_times_col.number_input("Repeat Times", value=1, min_value=1)
- sync_method_col, sync_iteration_interval_col = st.columns(2)
- sync_method = sync_method_col.selectbox("Sync Method", ["online", "offline"], index=0)
- sync_iteration_interval = sync_iteration_interval_col.number_input(
- "Sync Iteration Interval", value=10, min_value=1
+ self._set_configs_with_st_columns(
+ ["backend", "temperature", "top_p", "top_k", "seed", "logprobs"]
)
+
+ self._set_configs_with_st_columns(["enable_prefix_caching", "enforce_eager"])
+
+ def _expert_trainer_part(self):
+ self._set_configs_with_st_columns(
+ ["trainer_type", "algorithm_type", "sft_warmup_iteration", "eval_interval"]
+ )
+ self._check_sft_warmup_dataset_path()
+
+ if st.session_state["trainer_type"] == "verl":
+ self._expert_verl_trainer_part()
+
+ def _expert_verl_trainer_part(self):
+ rl_training_tab, rl_algorithm_tab, actor_ref_tab, critic_tab = st.tabs(
+ [
+ "RL Training Config",
+ "RL Algorithm Config",
+ "Actor and Ref Config",
+ "Critic Config",
+ ]
+ )
+ with rl_training_tab:
+ st.subheader("RL Training Config")
+ self._set_training_args()
+
+ self._set_configs_with_st_columns(["save_freq", "training_strategy", "resume_mode"])
+
+ if st.session_state["training_strategy"] == "fsdp":
+ self._set_configs_with_st_columns(["param_offload", "optimizer_offload"])
+ self._set_resume_from_path()
+
with st.expander("Advanced Config"):
- (
- runner_num_col,
- max_pending_requests_col,
- max_waiting_steps_col,
- dtype_col,
- ) = st.columns(4)
- runner_num = runner_num_col.number_input("Runner Num", value=32, min_value=1)
- max_pending_requests = max_pending_requests_col.number_input(
- "Max Pending Requests", value=32, min_value=1
- )
- max_waiting_steps = max_waiting_steps_col.number_input(
- "Max Waiting Steps", value=4, min_value=1
- )
- dtype = dtype_col.selectbox("Dtype", ["float16", "bfloat16", "float32"], index=1)
-
- (
- backend_col,
- temperature_col,
- top_p_col,
- top_k_col,
- seed_col,
- logprobs_col,
- ) = st.columns(6)
- backend = backend_col.selectbox("Backend", ["nccl"], index=0)
- temperature = temperature_col.number_input(
- "Temperature", value=1.0, min_value=0.0, max_value=2.0
+ self._set_configs_with_st_columns(["critic_warmup", "total_training_steps"])
+
+ self._set_default_hdfs_dir()
+
+ self._set_configs_with_st_columns(
+ ["remove_previous_ckpt_in_save", "del_local_ckpt_after_load"]
)
- top_p = top_p_col.number_input("Top P", value=1.0, min_value=0.0, max_value=1.0)
- top_k = top_k_col.number_input("Top K", value=1, min_value=1, max_value=512)
- seed = seed_col.number_input("Seed", value=42)
- logprobs = logprobs_col.number_input("Logprobs", value=0, min_value=0, max_value=20)
-
- enable_prefix_caching_col, enforce_eager_col = st.columns(2)
- enable_prefix_caching = enable_prefix_caching_col.checkbox(
- "Enable Prefix Caching", value=False
+
+ self._set_configs_with_st_columns(
+ ["max_actor_ckpt_to_keep", "max_critic_ckpt_to_keep"]
)
- enforce_eager = enforce_eager_col.checkbox("Enforce Eager", value=True)
- gpu_num = gpu_per_node * node_num - engine_num
+ with rl_algorithm_tab:
+ st.subheader("RL Algorithm Config")
+ self._set_configs_with_st_columns(["gamma", "lam", "adv_estimator"])
+ self._set_configs_with_st_columns(["norm_adv_by_std_in_grpo", "use_kl_in_reward"])
+ self._set_configs_with_st_columns(["kl_penalty", "kl_ctrl_type", "kl_ctrl_coef"])
+ self._set_configs_with_st_columns(["horizon", "target_kl"])
- with trainer_tab:
- trainer_type_col, sft_warmup_iteration_col, eval_interval_col = st.columns(3)
- trainer_type = trainer_type_col.selectbox("Trainer Type", ["verl"], index=0)
- sft_warmup_iteration = sft_warmup_iteration_col.number_input(
- "SFT Warmup Iteration", value=0, min_value=0
+ with actor_ref_tab:
+ st.subheader("Actor Model Config")
+ self._set_configs_with_st_columns(
+ [
+ "actor_ppo_micro_batch_size_per_gpu",
+ "ref_log_prob_micro_batch_size_per_gpu",
+ "actor_ulysses_sequence_parallel_size",
+ ]
)
- if sft_warmup_iteration and not sft_warmup_dataset_path.strip():
- self.unfinished_flag = True
- st.warning(
- "Please input SFT warmup dataset path when `sft_warmup_iteration` is not 0"
- )
- with buffer_advanced_tab:
- st.warning(
- "Please input SFT warmup dataset path when `sft_warmup_iteration` is not 0"
- )
- eval_interval = eval_interval_col.number_input("Eval Interval", value=1000, min_value=1)
- if trainer_type == "verl":
- trainer_config_path = st.text_input("Trainer Config Path", "")
- if not trainer_config_path.strip():
- self.unfinished_flag = True
- st.warning("Please input trainer config path")
-
- rl_training_tab, rl_algorithm_tab, actor_ref_tab, critic_tab = st.tabs(
- [
- "RL Training Config",
- "RL Algorithm Config",
- "Actor and Ref Config",
- "Critic Config",
- ]
+
+ self._set_configs_with_st_columns(
+ ["actor_lr", "actor_warmup_style", "actor_lr_warmup_steps_ratio"]
+ )
+
+ self._set_configs_with_st_columns(
+ ["actor_grad_clip", "actor_clip_ratio", "actor_entropy_coeff"]
+ )
+
+ self._set_actor_use_kl_loss()
+ if st.session_state["actor_use_kl_loss"]:
+ self._set_configs_with_st_columns(["actor_kl_loss_coef", "actor_kl_loss_type"])
+
+ if st.session_state["algorithm_type"] == "opmd":
+ self._set_configs_with_st_columns(
+ ["actor_tau", "actor_opmd_baseline", "actor_use_uid"]
)
- with rl_training_tab:
- st.subheader("RL Training Config")
- training_args = st.multiselect(
- "Training Args",
- [
- "balance_batch",
- "gradient_checkpointing",
- "remove_padding",
- "dynamic_bsz",
- ],
- default=[
- "balance_batch",
- "gradient_checkpointing",
- "remove_padding",
- "dynamic_bsz",
- ],
- )
- balance_batch = "balance_batch" in training_args
- enable_gradient_checkpointing = "gradient_checkpointing" in training_args
- use_remove_padding = "remove_padding" in training_args
- use_dynamic_bsz = "dynamic_bsz" in training_args
-
- (
- save_freq_col,
- training_strategy_col,
- resume_mode_col,
- ) = st.columns(3)
- if "save_freq" not in st.session_state:
- st.session_state.save_freq = 100
- if sync_method == "online":
- save_freq = save_freq_col.number_input(
- "Save Freq",
- value=st.session_state.save_freq,
- min_value=1,
- help="Set to `sync_iteration_interval` when `sync_method` is `offline`",
- )
- st.session_state.save_freq = save_freq
- else:
- st.session_state.save_freq = sync_iteration_interval
- save_freq = save_freq_col.number_input(
- "Save Freq",
- value=st.session_state.save_freq,
- min_value=1,
- help="Set to `sync_iteration_interval` when `sync_method` is `offline`",
- disabled=True,
- )
-
- training_strategy = training_strategy_col.selectbox(
- "Training Strategy",
- ["fsdp", "megatron"],
- index=0,
- help="megatron is not tested",
- )
- if training_strategy == "fsdp":
- param_offload_col, optimizer_offload_col = st.columns(2)
- param_offload = param_offload_col.checkbox(
- "FSDP Param Offload", value=False
- )
- optimizer_offload = optimizer_offload_col.checkbox(
- "FSDP Optimizer Offload", value=False
- )
- fsdp_config = {
- "wrap_policy": {"min_num_params": 0},
- "param_offload": param_offload,
- "optimizer_offload": optimizer_offload,
- "fsdp_size": -1,
- }
- else:
- fsdp_config = {}
-
- resume_mode = resume_mode_col.selectbox(
- "Resume Mode", ["disable", "auto", "resume_path"], index=1
- )
- if "resume_from_path" not in st.session_state:
- st.session_state.resume_from_path = ""
- if resume_mode == "resume_path":
- resume_from_path = st.text_input(
- "Resume Path", st.session_state.resume_from_path
- )
- st.session_state.resume_from_path = resume_from_path
- if not resume_from_path.strip() or "global_step_" not in resume_from_path:
- self.unfinished_flag = True
- st.warning(
- "Please input a valid resume path when `resume_mode` is `resume_path`"
- )
- else:
- resume_from_path = st.session_state.resume_from_path
-
- with st.expander("Advanced Config"):
- critic_warmup_col, total_training_steps_col = st.columns(2)
- critic_warmup = critic_warmup_col.number_input(
- "Critic Warmup Iteration", value=0, min_value=0
- )
- total_training_steps = total_training_steps_col.number_input(
- "Total Training Steps", value=None, min_value=1
- )
-
- default_hdfs_dir = st.text_input("Default HDFS Dir", None)
-
- (
- remove_previous_ckpt_in_save_col,
- del_local_ckpt_after_load_col,
- ) = st.columns(2)
- remove_previous_ckpt_in_save = remove_previous_ckpt_in_save_col.checkbox(
- "Remove Previous Checkpoint in Save", value=False
- )
- del_local_ckpt_after_load = del_local_ckpt_after_load_col.checkbox(
- "Delete Local Checkpoint After Load", value=False
- )
-
- max_actor_ckpt_to_keep_col, max_critic_ckpt_to_keep_col = st.columns(2)
- max_actor_ckpt_to_keep = max_actor_ckpt_to_keep_col.number_input(
- "Max Actor Checkpoint to Keep", value=None, min_value=1
- )
- max_critic_ckpt_to_keep = max_critic_ckpt_to_keep_col.number_input(
- "Max Critic Checkpoint to Keep", value=None, min_value=1
- )
-
- with rl_algorithm_tab:
- st.subheader("RL Algorithm Config")
- gamma_col, lam_col, adv_estimator_col = st.columns(3)
- gamma = gamma_col.number_input("Gamma", value=1.0)
- lam = lam_col.number_input("lam", value=1.0)
- adv_estimator = adv_estimator_col.selectbox(
- "Advantage Estimator",
- [member.value for member in AdvantageEstimator],
- index=0,
- )
- kl_penalty_col, kl_ctrl_type_col, kl_ctrl_coef_col = st.columns(3)
- kl_penalty = kl_penalty_col.selectbox(
- "KL Penalty", ["kl", "abs", "mse", "low_var_kl"], index=0
- )
- kl_ctrl_type = kl_ctrl_type_col.selectbox(
- "KL Ctrl Type", ["fixed", "adaptive"], index=0
- )
- kl_ctrl_coef = kl_ctrl_coef_col.number_input("KL Ctrl Coef", value=0.001)
-
- with actor_ref_tab:
- st.subheader("Actor Model Config")
- (
- actor_ppo_micro_batch_size_per_gpu_col,
- ref_log_prob_micro_batch_size_per_gpu_col,
- actor_ulysses_sequence_parallel_size_col,
- ) = st.columns(3)
- actor_ppo_micro_batch_size_per_gpu = (
- actor_ppo_micro_batch_size_per_gpu_col.number_input(
- "Micro Batch Size Per GPU for Actor", value=4, min_value=1
- )
- )
- ref_log_prob_micro_batch_size_per_gpu = (
- ref_log_prob_micro_batch_size_per_gpu_col.number_input(
- "Micro Batch Size Per GPU for Ref", value=8, min_value=1
- )
- )
- actor_ulysses_sequence_parallel_size = (
- actor_ulysses_sequence_parallel_size_col.number_input(
- "Ulysses Sequence Parallel Size", value=1, min_value=1, max_value=8
- )
- )
-
- (
- actor_lr_col,
- actor_warmup_style_col,
- actor_lr_warmup_steps_ratio_col,
- ) = st.columns(3)
- actor_lr = actor_lr_col.number_input(
- "Learning Rate for actor",
- value=1e-6,
- min_value=1e-7,
- max_value=1e-3,
- format="%.1e",
- )
- actor_warmup_style = actor_warmup_style_col.selectbox(
- "LR Warmup Style", ["constant", "cosine"], index=0
- )
- actor_lr_warmup_steps_ratio = actor_lr_warmup_steps_ratio_col.number_input(
- "LR Warmup Steps Ratio", value=0.0, min_value=0.0, max_value=1.0
- )
-
- (
- actor_alg_type_col,
- actor_grad_clip_col,
- actor_clip_ratio_col,
- actor_entropy_coeff_col,
- ) = st.columns(4)
- actor_alg_type = actor_alg_type_col.selectbox(
- "Algorithm Type", ["ppo", "opmd", "pairwise_opmd"], index=0
- )
- if "actor_tau" not in st.session_state:
- st.session_state.actor_tau = 0.0
- st.session_state.actor_opmd_baseline = "mean"
- st.session_state.actor_use_uid = False
- if actor_alg_type != "ppo":
- actor_tau_col, actor_opmd_baseline_col, actor_use_uid_col = st.columns(3)
- actor_tau = actor_tau_col.number_input(
- "Tau for OPMD",
- value=0.0,
- min_value=0.0,
- max_value=1.0,
- format="%.1e",
- )
- actor_opmd_baseline = actor_opmd_baseline_col.selectbox(
- "OPMD Baseline",
- ["mean", "logavgexp"],
- index=0,
- )
- actor_use_uid = actor_use_uid_col.checkbox("Use UID for OPMD", value=False)
- st.session_state.actor_tau = actor_tau
- st.session_state.actor_opmd_baseline = actor_opmd_baseline
- st.session_state.actor_use_uid = actor_use_uid
- else:
- actor_tau = st.session_state.actor_tau
- actor_opmd_baseline = st.session_state.actor_opmd_baseline
- actor_use_uid = st.session_state.actor_use_uid
-
- actor_grad_clip = actor_grad_clip_col.number_input(
- "Grad Clip", value=1.0, min_value=0.0, max_value=1.0
- )
- actor_clip_ratio = actor_clip_ratio_col.number_input(
- "Clip Ratio", value=0.2, min_value=0.0, max_value=1.0
- )
- actor_entropy_coeff = actor_entropy_coeff_col.number_input(
- "Entropy Coeff", value=0.001, min_value=0.0, max_value=1.0
- )
-
- actor_use_kl_loss = st.checkbox("Use KL Loss (True for GRPO)", value=False)
- if "actor_kl_loss_coef" not in st.session_state:
- st.session_state.actor_kl_loss_coef = 0.001
- st.session_state.actor_kl_loss_type = "low_var_kl"
- if actor_use_kl_loss:
- actor_kl_loss_coef_col, actor_kl_loss_type_col = st.columns(2)
- actor_kl_loss_coef = actor_kl_loss_coef_col.number_input(
- "KL Loss Coef",
- value=st.session_state.actor_kl_loss_coef,
- min_value=0.0,
- max_value=1.0,
- format="%.1e",
- )
- actor_kl_loss_type_candidates = ["kl", "abs", "mse", "low_var_kl"]
- actor_kl_loss_type = actor_kl_loss_type_col.selectbox(
- "KL Loss Type",
- actor_kl_loss_type_candidates,
- index=actor_kl_loss_type_candidates.index(
- st.session_state.actor_kl_loss_type
- ),
- )
- st.session_state.actor_kl_loss_coef = actor_kl_loss_coef
- st.session_state.actor_kl_loss_type = actor_kl_loss_type
- else:
- actor_kl_loss_coef = st.session_state.actor_kl_loss_coef
- actor_kl_loss_type = st.session_state.actor_kl_loss_type
-
- actor_checkpoint = st.multiselect(
- "Checkpoint",
- ["model", "hf_model", "optimizer", "extra"],
- default=["model", "hf_model", "optimizer", "extra"],
- )
-
- with critic_tab:
- st.subheader("Critic Model Config")
- (
- critic_lr_col,
- critic_warmup_style_col,
- critic_lr_warmup_steps_ratio_col,
- critic_grad_clip_col,
- ) = st.columns(4)
- critic_lr = critic_lr_col.number_input(
- "Learning Rate",
- key="Learning Rate for Critic",
- value=1e-6,
- min_value=1e-7,
- max_value=1e-3,
- format="%.1e",
- )
- critic_warmup_style = critic_warmup_style_col.selectbox(
- "LR Warmup Style",
- ["constant", "cosine"],
- key="LR Warmup Style for Critic",
- index=0,
- )
- critic_lr_warmup_steps_ratio = critic_lr_warmup_steps_ratio_col.number_input(
- "LR Warmup Steps Ratio",
- key="LR Warmup Steps Ratio for Critic",
- value=0.0,
- min_value=0.0,
- max_value=1.0,
- )
- critic_grad_clip = critic_grad_clip_col.number_input(
- "Grad Clip",
- key="Grad Clip for Critic",
- value=1.0,
- min_value=0.0,
- max_value=1.0,
- )
-
- (
- critic_cliprange_value_col,
- critic_ppo_micro_batch_size_per_gpu_col,
- critic_ulysses_sequence_parallel_size_col,
- ) = st.columns(3)
- critic_cliprange_value = critic_cliprange_value_col.number_input(
- "Cliprange Value",
- key="Cliprange Value for Critic",
- value=0.5,
- min_value=0.0,
- max_value=1.0,
- )
- critic_ppo_micro_batch_size_per_gpu = (
- critic_ppo_micro_batch_size_per_gpu_col.number_input(
- "Micro Batch Size Per GPU for Critic", value=8, min_value=1
- )
- )
- critic_ulysses_sequence_parallel_size = (
- critic_ulysses_sequence_parallel_size_col.number_input(
- "Ulysses Sequence Parallel Size",
- key="Ulysses Sequence Parallel Size for Critic",
- value=1,
- min_value=1,
- max_value=8,
- )
- )
-
- rollout_node_num = engine_num * tensor_parallel_size // gpu_per_node
- trainer_nnodes = node_num - rollout_node_num
- if node_num == 1:
- trainer_n_gpus_per_node = gpu_per_node - engine_num * tensor_parallel_size
- else:
- trainer_n_gpus_per_node = gpu_per_node
-
- if trainer_type == "verl":
- trainer_config = {
- "data": {
- "tokenizer": None,
- "train_files": "placeholder",
- "val_files": "placeholder",
- "prompt_key": "placeholder",
- "max_prompt_length": max_prompt_tokens,
- "max_response_length": max_response_tokens,
- "train_batch_size": batch_size_per_gpu * gpu_num * repeat_times,
- "val_batch_size": None,
- "return_raw_input_ids": False,
- "return_raw_chat": False,
- "shuffle": True,
- "filter_overlong_prompts": False,
- "truncation": "error",
- "image_key": "images",
- },
- "actor_rollout_ref": {
- "hybrid_engine": True,
- "model": {
- "path": model_path,
- "external_lib": None,
- "override_config": {},
- "enable_gradient_checkpointing": enable_gradient_checkpointing,
- "use_remove_padding": use_remove_padding,
- },
- "actor": {
- "strategy": training_strategy,
- "ppo_mini_batch_size": batch_size_per_gpu * gpu_num,
- "ppo_micro_batch_size_per_gpu": actor_ppo_micro_batch_size_per_gpu,
- "use_dynamic_bsz": use_dynamic_bsz,
- "ppo_max_token_len_per_gpu": repeat_times
- * (max_prompt_tokens + max_response_tokens),
- "grad_clip": actor_grad_clip,
- "clip_ratio": actor_clip_ratio,
- "entropy_coeff": actor_entropy_coeff,
- "use_kl_loss": actor_use_kl_loss,
- "kl_loss_coef": actor_kl_loss_coef,
- "kl_loss_type": actor_kl_loss_type,
- "ppo_epochs": 1, # TODO
- "shuffle": False,
- "ulysses_sequence_parallel_size": actor_ulysses_sequence_parallel_size,
- "checkpoint": {"contents": actor_checkpoint},
- "optim": {
- "lr": actor_lr,
- "lr_warmup_steps_ratio": actor_lr_warmup_steps_ratio,
- "warmup_style": actor_warmup_style,
- "total_training_steps": -1
- if total_training_steps is None
- else total_training_steps,
- },
- "fsdp_config": fsdp_config,
- "alg_type": actor_alg_type,
- "tau": actor_tau,
- "opmd_baseline": actor_opmd_baseline,
- "use_uid": actor_use_uid,
- },
- "ref": {
- "fsdp_config": fsdp_config,
- "log_prob_micro_batch_size_per_gpu": ref_log_prob_micro_batch_size_per_gpu,
- "log_prob_use_dynamic_bsz": "${actor_rollout_ref.actor.use_dynamic_bsz}",
- "log_prob_max_token_len_per_gpu": "${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}",
- "ulysses_sequence_parallel_size": "${actor_rollout_ref.actor.ulysses_sequence_parallel_size}",
- },
- "rollout": {
- "name": "vllm",
- "temperature": temperature,
- "top_k": -1,
- "top_p": 1,
- "use_fire_sampling": False,
- "prompt_length": "${data.max_prompt_length}",
- "response_length": "${data.max_response_length}",
- "dtype": "bfloat16",
- "gpu_memory_utilization": 0.4,
- "ignore_eos": False,
- "enforce_eager": True,
- "free_cache_engine": True,
- "load_format": "dummy_dtensor",
- "tensor_model_parallel_size": 2,
- "max_num_batched_tokens": 8192,
- "max_model_len": None,
- "max_num_seqs": 1024,
- "log_prob_micro_batch_size_per_gpu": 4,
- "log_prob_use_dynamic_bsz": "${actor_rollout_ref.actor.use_dynamic_bsz}",
- "log_prob_max_token_len_per_gpu": "${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}",
- "disable_log_stats": True,
- "enable_chunked_prefill": True,
- "do_sample": True,
- "n": repeat_times,
- },
+
+ self._set_actor_checkpoint()
+
+ with critic_tab:
+ st.subheader("Critic Model Config")
+ self._set_configs_with_st_columns(
+ ["critic_ppo_micro_batch_size_per_gpu", "critic_ulysses_sequence_parallel_size"]
+ )
+
+ self._set_configs_with_st_columns(
+ ["critic_lr", "critic_warmup_style", "critic_lr_warmup_steps_ratio"]
+ )
+
+ self._set_configs_with_st_columns(["critic_grad_clip", "critic_cliprange_value"])
+
+ def expert_mode(self):
+ model_tab, buffer_tab, connector_tab, trainer_tab = st.tabs(
+ ["Model", "Data", "Explorer and Synchronizer", "Trainer"]
+ )
+ with model_tab:
+ self._expert_model_part()
+
+ with buffer_tab:
+ self._expert_buffer_part()
+
+ with connector_tab:
+ self._expert_connector_part()
+
+ with trainer_tab:
+ self._expert_trainer_part()
+
+ def _generate_verl_config(self, trainer_nnodes: int = 1, trainer_n_gpus_per_node: int = 8):
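+        # Unpack the selected training options into the individual boolean flags
+        # expected by the veRL trainer config assembled below.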
+ balance_batch = "balance_batch" in st.session_state["training_args"]
+ enable_gradient_checkpointing = (
+ "gradient_checkpointing" in st.session_state["training_args"]
+ )
+ use_remove_padding = "remove_padding" in st.session_state["training_args"]
+ use_dynamic_bsz = "dynamic_bsz" in st.session_state["training_args"]
+
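+        # Offload options are only meaningful under the FSDP strategy;
+        # other strategies get an empty fsdp_config.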
+ if st.session_state["training_strategy"] == "fsdp":
+ fsdp_config = {
+ "wrap_policy": {"min_num_params": 0},
+ "param_offload": st.session_state["param_offload"],
+ "optimizer_offload": st.session_state["optimizer_offload"],
+ "fsdp_size": -1,
+ }
+ else:
+ fsdp_config = {}
+
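+        # Per-GPU token budget: each task is sampled repeat_times times, and every
+        # sample may use up to max_prompt_tokens + max_response_tokens tokens.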
+        ppo_epochs = 1  # TODO
+ ppo_max_token_len_per_gpu = st.session_state["repeat_times"] * (
+ st.session_state["max_prompt_tokens"] + st.session_state["max_response_tokens"]
+ )
+
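+        # Fall back to the actor model path when no separate critic model path is given.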
+ critic_model_path = (
+ st.session_state["critic_model_path"].strip()
+ if st.session_state["critic_model_path"].strip()
+ else st.session_state["model_path"]
+ )
+ trainer_config = {
+ "data": {
+ "tokenizer": None,
+ "train_files": "placeholder",
+ "val_files": "placeholder",
+ "prompt_key": "placeholder",
+ "max_prompt_length": st.session_state["max_prompt_tokens"],
+ "max_response_length": st.session_state["max_response_tokens"],
+ "train_batch_size": st.session_state["task_num_per_batch"]
+ * st.session_state["repeat_times"],
+ "val_batch_size": None,
+ "return_raw_input_ids": False,
+ "return_raw_chat": False,
+ "shuffle": True,
+ "filter_overlong_prompts": False,
+ "truncation": "error",
+ "image_key": "images",
+ },
+ "actor_rollout_ref": {
+ "hybrid_engine": True,
+ "model": {
+ "path": st.session_state["model_path"],
+ "external_lib": None,
+ "override_config": {},
+ "enable_gradient_checkpointing": enable_gradient_checkpointing,
+ "use_remove_padding": use_remove_padding,
},
- "critic": {
- "strategy": training_strategy,
+ "actor": {
+ "strategy": st.session_state["training_strategy"],
+ "ppo_mini_batch_size": st.session_state["task_num_per_batch"],
+ "ppo_micro_batch_size_per_gpu": st.session_state[
+ "actor_ppo_micro_batch_size_per_gpu"
+ ],
+ "use_dynamic_bsz": use_dynamic_bsz,
+ "ppo_max_token_len_per_gpu": ppo_max_token_len_per_gpu,
+ "grad_clip": st.session_state["actor_grad_clip"],
+ "clip_ratio": st.session_state["actor_clip_ratio"],
+ "entropy_coeff": st.session_state["actor_entropy_coeff"],
+ "use_kl_loss": st.session_state["actor_use_kl_loss"],
+ "kl_loss_coef": st.session_state["actor_kl_loss_coef"],
+ "kl_loss_type": st.session_state["actor_kl_loss_type"],
+ "ppo_epochs": ppo_epochs,
+ "shuffle": False,
+ "ulysses_sequence_parallel_size": st.session_state[
+ "actor_ulysses_sequence_parallel_size"
+ ],
+ "checkpoint": {"contents": st.session_state["actor_checkpoint"]},
"optim": {
- "lr": critic_lr,
- "lr_warmup_steps_ratio": critic_warmup_style,
- "warmup_style": critic_lr_warmup_steps_ratio,
+ "lr": st.session_state["actor_lr"],
+ "lr_warmup_steps_ratio": st.session_state["actor_lr_warmup_steps_ratio"],
+ "warmup_style": st.session_state["actor_warmup_style"],
"total_training_steps": -1
- if total_training_steps is None
- else total_training_steps,
- },
- "model": {
- "path": critic_model_path,
- "tokenizer_path": "${actor_rollout_ref.model.path}",
- "override_config": {},
- "external_lib": "${actor_rollout_ref.model.external_lib}",
- "enable_gradient_checkpointing": enable_gradient_checkpointing,
- "use_remove_padding": use_remove_padding,
- "fsdp_config": fsdp_config,
+ if st.session_state["total_training_steps"] is None
+ else st.session_state["total_training_steps"],
},
- "ppo_mini_batch_size": "${actor_rollout_ref.actor.ppo_mini_batch_size}",
- "ppo_micro_batch_size_per_gpu": critic_ppo_micro_batch_size_per_gpu,
- "forward_micro_batch_size_per_gpu": "${critic.ppo_micro_batch_size_per_gpu}",
- "use_dynamic_bsz": use_dynamic_bsz,
- "ppo_max_token_len_per_gpu": repeat_times
- * (max_prompt_tokens + max_response_tokens)
- * 2,
- "forward_max_token_len_per_gpu": "${critic.ppo_max_token_len_per_gpu}",
- "ulysses_sequence_parallel_size": critic_ulysses_sequence_parallel_size,
- "ppo_epochs": "${actor_rollout_ref.actor.ppo_epochs}",
- "shuffle": "${actor_rollout_ref.actor.shuffle}",
- "grad_clip": critic_grad_clip,
- "cliprange_value": critic_cliprange_value,
+ "fsdp_config": copy.deepcopy(fsdp_config),
+ "alg_type": st.session_state["algorithm_type"],
+ "tau": st.session_state["actor_tau"],
+ "opmd_baseline": st.session_state["actor_opmd_baseline"],
+ "use_uid": st.session_state["actor_use_uid"],
},
- "reward_model": {
- "enable": False,
- "strategy": "fsdp",
- "model": {
- "input_tokenizer": "${actor_rollout_ref.model.path}",
- "path": "~/models/FsfairX-LLaMA3-RM-v0.1",
- "external_lib": "${actor_rollout_ref.model.external_lib}",
- "use_remove_padding": False,
- "fsdp_config": {
- "min_num_params": 0,
- "param_offload": False,
- "fsdp_size": -1,
- },
- },
- "ulysses_sequence_parallel_size": 1,
- "use_dynamic_bsz": "${critic.use_dynamic_bsz}",
- "forward_max_token_len_per_gpu": "${critic.forward_max_token_len_per_gpu}",
- "reward_manager": "naive",
+ "ref": {
+ "fsdp_config": copy.deepcopy(fsdp_config),
+ "log_prob_micro_batch_size_per_gpu": st.session_state[
+ "ref_log_prob_micro_batch_size_per_gpu"
+ ],
+ "log_prob_use_dynamic_bsz": use_dynamic_bsz,
+ "log_prob_max_token_len_per_gpu": ppo_max_token_len_per_gpu,
+ "ulysses_sequence_parallel_size": st.session_state[
+ "actor_ulysses_sequence_parallel_size"
+ ],
},
- "custom_reward_function": {"path": None, "name": "compute_score"},
- "algorithm": {
- "gamma": gamma,
- "lam": lam,
- "adv_estimator": adv_estimator,
- "kl_penalty": kl_penalty,
- "kl_ctrl": {"type": kl_ctrl_type, "kl_coef": kl_ctrl_coef},
+ "rollout": {
+ "name": "vllm",
+ "temperature": st.session_state["temperature"],
+ "top_k": -1,
+ "top_p": 1,
+ "use_fire_sampling": False,
+ "prompt_length": st.session_state["max_prompt_tokens"],
+ "response_length": st.session_state["max_response_tokens"],
+ "dtype": "bfloat16",
+ "gpu_memory_utilization": 0.4,
+ "ignore_eos": False,
+ "enforce_eager": True,
+ "free_cache_engine": True,
+ "load_format": "dummy_dtensor",
+ "tensor_model_parallel_size": 2,
+ "max_num_batched_tokens": 8192,
+ "max_model_len": None,
+ "max_num_seqs": 1024,
+ "log_prob_micro_batch_size_per_gpu": 4,
+ "log_prob_use_dynamic_bsz": use_dynamic_bsz,
+ "log_prob_max_token_len_per_gpu": ppo_max_token_len_per_gpu,
+ "disable_log_stats": True,
+ "enable_chunked_prefill": True,
+ "do_sample": True,
+ "n": st.session_state["repeat_times"],
},
- "trainer": {
- "balance_batch": balance_batch,
- "total_epochs": total_epoch,
- "project_name": project,
- "experiment_name": name,
- "logger": ["wandb"],
- "val_generations_to_log_to_wandb": 0,
- "nnodes": trainer_nnodes,
- "n_gpus_per_node": trainer_n_gpus_per_node,
- "save_freq": save_freq,
- "resume_mode": resume_mode,
- "resume_from_path": resume_from_path,
- "test_freq": 100,
- "critic_warmup": critic_warmup,
- "default_hdfs_dir": default_hdfs_dir,
- "remove_previous_ckpt_in_save": remove_previous_ckpt_in_save,
- "del_local_ckpt_after_load": del_local_ckpt_after_load,
- "default_local_dir": checkpoint_path,
- "val_before_train": False,
- "sync_freq": sync_iteration_interval,
- "max_actor_ckpt_to_keep": max_actor_ckpt_to_keep,
- "max_critic_ckpt_to_keep": max_critic_ckpt_to_keep,
+ },
+ "critic": {
+ "strategy": st.session_state["training_strategy"],
+ "optim": {
+ "lr": st.session_state["critic_lr"],
+ "lr_warmup_steps_ratio": st.session_state["critic_lr_warmup_steps_ratio"],
+ "warmup_style": st.session_state["critic_warmup_style"],
+ "total_training_steps": -1
+ if st.session_state["total_training_steps"] is None
+ else st.session_state["total_training_steps"],
},
- }
+ "model": {
+ "path": critic_model_path,
+ "tokenizer_path": critic_model_path,
+ "override_config": {},
+ "external_lib": None,
+ "enable_gradient_checkpointing": enable_gradient_checkpointing,
+ "use_remove_padding": use_remove_padding,
+ "fsdp_config": copy.deepcopy(fsdp_config),
+ },
+ "ppo_mini_batch_size": st.session_state["task_num_per_batch"],
+ "ppo_micro_batch_size_per_gpu": st.session_state[
+ "critic_ppo_micro_batch_size_per_gpu"
+ ],
+ "forward_micro_batch_size_per_gpu": st.session_state[
+ "critic_ppo_micro_batch_size_per_gpu"
+ ],
+ "use_dynamic_bsz": use_dynamic_bsz,
+ "ppo_max_token_len_per_gpu": ppo_max_token_len_per_gpu * 2,
+ "forward_max_token_len_per_gpu": ppo_max_token_len_per_gpu * 2,
+ "ulysses_sequence_parallel_size": st.session_state[
+ "critic_ulysses_sequence_parallel_size"
+ ],
+ "ppo_epochs": ppo_epochs,
+ "shuffle": False,
+ "grad_clip": st.session_state["critic_grad_clip"],
+ "cliprange_value": st.session_state["critic_cliprange_value"],
+ },
+ "reward_model": {
+ "enable": False,
+ "strategy": "fsdp",
+ "model": {
+ "input_tokenizer": st.session_state["model_path"],
+ "path": "~/models/FsfairX-LLaMA3-RM-v0.1",
+ "external_lib": None,
+ "use_remove_padding": False,
+ "fsdp_config": {
+ "min_num_params": 0,
+ "param_offload": False,
+ "fsdp_size": -1,
+ },
+ },
+ "ulysses_sequence_parallel_size": 1,
+ "use_dynamic_bsz": use_dynamic_bsz,
+ "forward_max_token_len_per_gpu": ppo_max_token_len_per_gpu * 2,
+ "reward_manager": "naive",
+ },
+ "custom_reward_function": {"path": None, "name": "compute_score"},
+ "algorithm": {
+ "gamma": st.session_state["gamma"],
+ "lam": st.session_state["lam"],
+ "adv_estimator": st.session_state["adv_estimator"],
+ "kl_penalty": st.session_state["kl_penalty"],
+ "kl_ctrl": {
+ "type": st.session_state["kl_ctrl_type"],
+ "kl_coef": st.session_state["kl_ctrl_coef"],
+ },
+ },
+ "trainer": {
+ "balance_batch": balance_batch,
+ "total_epochs": st.session_state["total_epochs"],
+ "project_name": st.session_state["project"],
+ "experiment_name": st.session_state["exp_name"],
+ "logger": ["wandb"],
+ "val_generations_to_log_to_wandb": 0,
+ "nnodes": trainer_nnodes,
+ "n_gpus_per_node": trainer_n_gpus_per_node,
+ "save_freq": st.session_state["save_freq"],
+ "resume_mode": st.session_state["resume_mode"],
+ "resume_from_path": st.session_state["resume_from_path"],
+ "test_freq": 100,
+ "critic_warmup": st.session_state["critic_warmup"],
+ "default_hdfs_dir": st.session_state["default_hdfs_dir"],
+ "remove_previous_ckpt_in_save": st.session_state["remove_previous_ckpt_in_save"],
+ "del_local_ckpt_after_load": st.session_state["del_local_ckpt_after_load"],
+ "default_local_dir": st.session_state["checkpoint_path"],
+ "val_before_train": False,
+ "sync_freq": st.session_state["sync_iteration_interval"],
+ "max_actor_ckpt_to_keep": st.session_state["max_actor_ckpt_to_keep"],
+ "max_critic_ckpt_to_keep": st.session_state["max_critic_ckpt_to_keep"],
+ },
+ }
+ return trainer_config
+
+ def generate_config(self):
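+        # Rollout engines occupy engine_num * tensor_parallel_size GPUs;
+        # the remaining GPUs (and nodes) are assigned to the trainer.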
+ trainer_nnodes = (
+ st.session_state["node_num"]
+ - st.session_state["engine_num"]
+ * st.session_state["tensor_parallel_size"]
+ // st.session_state["gpu_per_node"]
+ )
+ if st.session_state["node_num"] == 1:
+ trainer_n_gpus_per_node = (
+ st.session_state["gpu_per_node"]
+ - st.session_state["engine_num"] * st.session_state["tensor_parallel_size"]
+ )
else:
- raise ValueError(f"Invalid trainer type: {trainer_type}")
+ trainer_n_gpus_per_node = st.session_state["gpu_per_node"]
- if st.button("Generate Config", disabled=self.unfinished_flag):
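+        # Default to a SQLite file under the checkpoint cache directory
+        # when no explicit db_url is provided.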
+ db_url = (
+ st.session_state["db_url"]
+ if st.session_state["db_url"].strip()
+ else f"sqlite:///{os.path.join(st.session_state['checkpoint_path'], '.cache', st.session_state['project'], st.session_state['exp_name'])}/data.db"
+ )
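+        # Treat paths containing "://" as database URLs (SQL storage);
+        # everything else is read as a local file.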
+ sft_storage_type = (
+ StorageType.SQL.value
+ if "://" in st.session_state["sft_warmup_dataset_path"]
+ else StorageType.FILE.value
+        )  # TODO
+ if st.session_state["trainer_type"] == "verl":
+ trainer_config = self._generate_verl_config(
+ trainer_nnodes=trainer_nnodes, trainer_n_gpus_per_node=trainer_n_gpus_per_node
+ )
+ else:
+ raise ValueError(f"Invalid trainer type: {st.session_state['trainer_type']}")
+
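+        # Disable the generate button until every required field is filled, and
+        # list the missing fields in the button's help message.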
+ if len(self.unfinished_fields) > 0:
+ disable_generate = True
+ help_messages = (
+                f"Please check the following fields: `{'`, `'.join(self.unfinished_fields)}`"
+ )
+ else:
+ disable_generate = False
+ help_messages = None
+ if st.button(
+ "Generate Config",
+ disabled=disable_generate,
+ help=help_messages,
+ ):
config = {
"data": {
- "total_epochs": total_epoch,
- "batch_size": batch_size_per_gpu * gpu_num,
- "dataset_path": dataset_path,
- "default_workflow_type": default_workflow_type,
- "default_reward_fn_type": default_reward_fn_type,
- "train_split": train_split,
- "eval_split": eval_split,
+ "total_epochs": st.session_state["total_epochs"],
+ "batch_size": st.session_state["task_num_per_batch"],
+ "dataset_path": st.session_state["dataset_path"],
+ "default_workflow_type": st.session_state["default_workflow_type"],
+ "default_reward_fn_type": st.session_state["default_reward_fn_type"],
+ "train_split": st.session_state["train_split"],
+ "eval_split": st.session_state["eval_split"],
"format_config": {
- "prompt_key": prompt_key,
- "response_key": response_key,
+ "prompt_key": st.session_state["prompt_key"],
+ "response_key": st.session_state["response_key"],
},
},
"model": {
- "model_path": model_path,
- "max_prompt_tokens": max_prompt_tokens,
- "max_response_tokens": max_response_tokens,
- "checkpoint_path": checkpoint_path,
+ "model_path": st.session_state["model_path"],
+ "max_prompt_tokens": st.session_state["max_prompt_tokens"],
+ "max_response_tokens": st.session_state["max_response_tokens"],
+ "checkpoint_path": st.session_state["checkpoint_path"],
},
"cluster": {
- "node_num": node_num,
- "gpu_per_node": gpu_per_node,
+ "node_num": st.session_state["node_num"],
+ "gpu_per_node": st.session_state["gpu_per_node"],
},
"buffer": {
- "storage_type": storage_type,
"db_url": db_url,
- "read_batch_size": batch_size_per_gpu * gpu_num * repeat_times,
- "max_retry_times": max_retry_times,
- "max_retry_interval": max_retry_interval,
+ "read_batch_size": st.session_state["task_num_per_batch"]
+ * st.session_state["repeat_times"],
+ "max_retry_times": st.session_state["max_retry_times"],
+ "max_retry_interval": st.session_state["max_retry_interval"],
+ "train_dataset": {
+ "name": "experience_buffer", # TODO
+ "storage_type": st.session_state["storage_type"],
+ "algorithm_type": st.session_state["algorithm_type"],
+ "path": db_url,
+ },
+ "sft_warmup_dataset": {
+ "name": "sft_warmup_dataset",
+ "storage_type": sft_storage_type,
+ "algorithm_type": AlgorithmType.SFT.value,
+ "path": st.session_state["sft_warmup_dataset_path"],
+ },
},
"explorer": {
- "engine_type": engine_type,
- "engine_num": engine_num,
- "runner_num": runner_num,
- "tensor_parallel_size": tensor_parallel_size,
- "enable_prefix_caching": enable_prefix_caching,
- "enforce_eager": enforce_eager,
- "dtype": dtype,
- "temperature": temperature,
- "top_p": top_p,
- "top_k": top_k,
- "seed": seed,
- "logprobs": logprobs,
- "repeat_times": repeat_times,
- "backend": backend,
- "max_pending_requests": max_pending_requests,
- "max_waiting_steps": max_waiting_steps,
+ "engine_type": st.session_state["engine_type"],
+ "engine_num": st.session_state["engine_num"],
+ "runner_num": st.session_state["runner_num"],
+ "tensor_parallel_size": st.session_state["tensor_parallel_size"],
+ "enable_prefix_caching": st.session_state["enable_prefix_caching"],
+ "enforce_eager": st.session_state["enforce_eager"],
+ "dtype": st.session_state["dtype"],
+ "temperature": st.session_state["temperature"],
+ "top_p": st.session_state["top_p"],
+ "top_k": st.session_state["top_k"],
+ "seed": st.session_state["seed"],
+ "logprobs": st.session_state["logprobs"],
+ "repeat_times": st.session_state["repeat_times"],
+ "backend": st.session_state["backend"],
+ "max_pending_requests": st.session_state["max_pending_requests"],
+ "max_waiting_steps": st.session_state["max_waiting_steps"],
},
"synchronizer": {
- "sync_method": sync_method,
- "sync_iteration_interval": sync_iteration_interval,
+ "sync_method": st.session_state["sync_method"],
+ "sync_iteration_interval": st.session_state["sync_iteration_interval"],
},
"trainer": {
- "trainer_type": trainer_type,
- "trainer_config_path": trainer_config_path,
- "sft_warmup_iteration": sft_warmup_iteration,
- "eval_interval": eval_interval,
+ "trainer_type": st.session_state["trainer_type"],
+ "algorithm_type": st.session_state["algorithm_type"],
+ "trainer_config": trainer_config,
+ "sft_warmup_iteration": st.session_state["sft_warmup_iteration"],
+ "eval_interval": st.session_state["eval_interval"],
},
"monitor": {
- "project": project,
- "name": name,
+ "project": st.session_state["project"],
+ "name": st.session_state["exp_name"],
+ "monitor_type": st.session_state["monitor_type"],
},
}
st.header("Generated Config File")
- st.subheader("Overall Config File")
+ st.subheader("Config File")
yaml_config = yaml.dump(config, allow_unicode=True, sort_keys=False)
st.code(yaml_config, language="yaml")
- st.subheader("Trainer Config File")
- trainer_config = yaml.dump(trainer_config, allow_unicode=True, sort_keys=False)
- st.code(trainer_config, language="yaml")
-
- def main(self):
- mode = st.pills(
- "Select Mode",
- options=["Beginer Mode", "Expert Mode"],
- default="Expert Mode",
- label_visibility="collapsed",
- )
- if mode == "Beginer Mode":
- self.beginer_mode()
- else:
- self.expert_mode()
if __name__ == "__main__":
config_manager = ConfigManager()
- config_manager.main()
diff --git a/trinity/trainer/verl/fsdp_workers.py b/trinity/trainer/verl/fsdp_workers.py
index 94f5939f53..0273fd0544 100644
--- a/trinity/trainer/verl/fsdp_workers.py
+++ b/trinity/trainer/verl/fsdp_workers.py
@@ -928,17 +928,6 @@ def _build_critic_model_optimizer(self, config):
)
critic_model_config.num_labels = 1
- use_remove_padding = config.model.get("use_remove_padding", False)
- if use_remove_padding:
- from verl.models.registry import check_model_support_rmpad
-
- check_model_support_rmpad(critic_model_config.model_type)
-
- if use_remove_padding and self.ulysses_sequence_parallel_size > 1:
- from verl.models.transformers.monkey_patch import apply_monkey_patch
-
- apply_monkey_patch(critic_model_config, verbose=True)
-
init_context = get_init_weight_context_manager(
use_meta_tensor=not critic_model_config.tie_word_embeddings, mesh=self.device_mesh
)