Merged (changes from 12 commits)
19 changes: 16 additions & 3 deletions docs/sphinx_doc/source/tutorial/trinity_configs.md
@@ -3,6 +3,18 @@
The following is the main config file for Trinity-RFT. Take `countdown.yaml` as an example.


## Monitor

```yaml
monitor:
project: "Trinity-RFT-countdown"
name: "qwen2.5-1.5B-countdown"
```

- `monitor.project`: The project name. It must be set manually.
- `monitor.name`: The name of the experiment. It must be set manually.
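
For orientation, `monitor.project` and `monitor.name` presumably map onto a tracking run. Below is a hypothetical sketch assuming a wandb backend (the trainer config elsewhere mentions rollout preview in wandb); the actual Monitor class may wrap this differently:

```python
import wandb

# Hypothetical illustration of how the monitor fields could be consumed;
# not a copy of Trinity-RFT's Monitor implementation.
run = wandb.init(
    project="Trinity-RFT-countdown",  # monitor.project
    name="qwen2.5-1.5B-countdown",    # monitor.name
)
run.log({"step": 1})  # metrics logged during training
run.finish()
```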


## Monitor

```yaml
@@ -363,7 +375,7 @@ trainer:
save_freq: 100
# auto: find the last ckpt to resume from; if none is found, start from scratch
resume_mode: auto # or disable, or resume_path if resume_from_path is set
resume_from_path: False
resume_from_path: ""
test_freq: 100
critic_warmup: 0
default_hdfs_dir: null
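
Since `resume_from_path` is now a string instead of a boolean, resuming from an explicit checkpoint would presumably look like the following (the checkpoint path is hypothetical):

```yaml
trainer:
  resume_mode: resume_path
  resume_from_path: /PATH/TO/CHECKPOINT/global_step_100
```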
@@ -383,8 +395,9 @@ trainer:
- `actor_rollout_ref.actor.grad_clip`: Gradient clip for actor model training.
- `actor_rollout_ref.actor.clip_ratio`: The clipping ratio used to compute the PPO policy loss.
- `actor_rollout_ref.actor.entropy_coeff`: The entropy bonus coefficient used in the policy loss.
- `actor_rollout_ref.actor.use_kl_loss`: True for GRPO.
- `actor_rollout_ref.actor.kl_loss_coef`: Used for GRPO, optional value is `kl`, `abs`, `mse` or `low_var_kl`.
- `actor_rollout_ref.actor.use_kl_loss`: Whether to enable the KL loss.
- `actor_rollout_ref.actor.kl_loss_coef`: The coefficient of the KL loss.
- `actor_rollout_ref.actor.kl_loss_type`: How to compute the KL loss; optional values are `kl`, `abs`, `mse`, or `low_var_kl` (a sketch follows this list).
- `actor_rollout_ref.actor.ulysses_sequence_parallel_size`: Ulysses sequence parallel size.
- `actor_rollout_ref.actor.alg_type`: Used for OPMD; optional values are `ppo`, `opmd`, or `pairwise_opmd`.
- `actor_rollout_ref.actor.tau`: Strength of regularization w.r.t. the old / reference policy.
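
For reference, here is a minimal sketch of how these four `kl_loss_type` estimators are commonly computed, assuming per-token log-probabilities `logprob` (actor) and `ref_logprob` (reference policy); this follows the usual verl-style estimators and is illustrative, not a copy of Trinity-RFT's implementation:

```python
import torch

def kl_loss(logprob: torch.Tensor, ref_logprob: torch.Tensor, kl_loss_type: str) -> torch.Tensor:
    """Per-token KL penalty between the actor and the reference policy."""
    if kl_loss_type == "kl":
        return logprob - ref_logprob  # plain k1 estimator
    if kl_loss_type == "abs":
        return (logprob - ref_logprob).abs()
    if kl_loss_type == "mse":
        return 0.5 * (logprob - ref_logprob).square()
    if kl_loss_type == "low_var_kl":
        # k3 estimator: exp(kl) - kl - 1, non-negative and low variance
        kl = ref_logprob - logprob
        return torch.clamp(kl.exp() - kl - 1, min=-10, max=10)
    raise ValueError(f"unknown kl_loss_type: {kl_loss_type}")
```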
1 change: 0 additions & 1 deletion examples/dpo_humanlike/dpo.yaml
@@ -22,7 +22,6 @@ buffer:
train_dataset:
name: dpo_buffer
storage_type: file
algorithm_type: dpo
path: '/PATH/TO/DATASET/'
kwargs:
prompt_type: plaintext # plaintext/messages
1 change: 0 additions & 1 deletion examples/dpo_humanlike/train_dpo.yaml
@@ -173,7 +173,6 @@ trainer:
save_freq: 30
# auto: find the last ckpt to resume. If can't find, start from scratch
resume_mode: auto # or auto or resume_path if
resume_from_path: False
test_freq: 5
critic_warmup: 0
default_hdfs_dir: null
1 change: 0 additions & 1 deletion examples/grpo_alfworld/alfworld.yaml
@@ -21,7 +21,6 @@ buffer:
train_dataset:
name: alfworld_buffer
storage_type: queue
algorithm_type: ppo
path: 'sqlite:///alfworld.db'
explorer:
engine_type: vllm_async
1 change: 0 additions & 1 deletion examples/grpo_alfworld/train_alfworld.yaml
@@ -172,7 +172,6 @@ trainer:
save_freq: 1
# auto: find the last ckpt to resume. If can't find, start from scratch
resume_mode: auto # or auto or resume_path if
resume_from_path: False
test_freq: 100
critic_warmup: 0
default_hdfs_dir: null
2 changes: 0 additions & 2 deletions examples/grpo_gsm8k/gsm8k.yaml
@@ -35,12 +35,10 @@ buffer:
train_dataset:
name: gsm8k_buffer
storage_type: queue
algorithm_type: ppo
path: 'sqlite:///gsm8k.db'
# sft_warmup_dataset: # Uncomment these to enable sft warmup
# name: warmup_data
# storage_type: file
# algorithm_type: sft
# path: '/PATH/TO/WARMUP_DATA/'
# kwargs:
# prompt_type: plaintext
1 change: 0 additions & 1 deletion examples/grpo_gsm8k/train_gsm8k.yaml
@@ -177,7 +177,6 @@ trainer:
save_freq: 100
# auto: find the last ckpt to resume. If can't find, start from scratch
resume_mode: auto # or auto or resume_path if
resume_from_path: False
test_freq: 5
critic_warmup: 0
default_hdfs_dir: null
1 change: 0 additions & 1 deletion examples/grpo_math/math.yaml
@@ -27,7 +27,6 @@ buffer:
train_dataset:
name: math_buffer
storage_type: queue
algorithm_type: ppo
path: 'sqlite:////math.db'
explorer:
engine_type: vllm_async
1 change: 0 additions & 1 deletion examples/grpo_math/train_math.yaml
@@ -169,7 +169,6 @@ trainer:
save_freq: 100
# auto: find the last ckpt to resume. If can't find, start from scratch
resume_mode: auto # or auto or resume_path if
resume_from_path: False
test_freq: 5
critic_warmup: 0
default_hdfs_dir: null
1 change: 0 additions & 1 deletion examples/grpo_sciworld/sciworld.yaml
@@ -21,7 +21,6 @@ buffer:
train_dataset:
name: sciworld_buffer
storage_type: queue
algorithm_type: ppo
path: 'sqlite:///sciworld.db'
explorer:
engine_type: vllm_async
1 change: 0 additions & 1 deletion examples/grpo_sciworld/train_sciworld.yaml
@@ -167,7 +167,6 @@ trainer:
save_freq: 1
# auto: find the last ckpt to resume. If can't find, start from scratch
resume_mode: auto # or auto or resume_path if
resume_from_path: False
test_freq: 100
critic_warmup: 0
default_hdfs_dir: null
1 change: 0 additions & 1 deletion examples/grpo_webshop/train_webshop.yaml
@@ -172,7 +172,6 @@ trainer:
save_freq: 1
# auto: find the last ckpt to resume. If can't find, start from scratch
resume_mode: auto # or auto or resume_path if
resume_from_path: False
test_freq: 100
critic_warmup: 0
default_hdfs_dir: null
1 change: 0 additions & 1 deletion examples/grpo_webshop/webshop.yaml
@@ -21,7 +21,6 @@ buffer:
train_dataset:
name: webshop_buffer
storage_type: queue
algorithm_type: ppo
path: 'sqlite:///webshop.db'
explorer:
engine_type: vllm_async
1 change: 0 additions & 1 deletion examples/opmd_gsm8k/opmd_gsm8k.yaml
@@ -20,7 +20,6 @@ buffer:
train_dataset:
name: gsm8k_buffer
storage_type: queue
algorithm_type: opmd
path: 'sqlite:///gsm8k_opmd.db'
explorer:
engine_type: vllm_async
1 change: 0 additions & 1 deletion examples/opmd_gsm8k/train_opmd_gsm8k.yaml
@@ -204,7 +204,6 @@ trainer:
save_freq: 100
# auto: find the last ckpt to resume. If can't find, start from scratch
resume_mode: auto # or auto or resume_path if
resume_from_path: False
test_freq: 100
critic_warmup: 0
default_hdfs_dir: null
1 change: 0 additions & 1 deletion examples/ppo_countdown/countdown.yaml
@@ -23,7 +23,6 @@ buffer:
train_dataset:
name: countdown_buffer
storage_type: queue
algorithm_type: ppo
path: 'sqlite:////countdown.db'
explorer:
engine_type: vllm_async
1 change: 0 additions & 1 deletion examples/ppo_countdown/train_countdown.yaml
@@ -179,7 +179,6 @@ trainer:
save_freq: 100
# auto: find the last ckpt to resume. If can't find, start from scratch
resume_mode: auto # or auto or resume_path if
resume_from_path: False
test_freq: 100
critic_warmup: 0
default_hdfs_dir: null
1 change: 1 addition & 0 deletions pyproject.toml
@@ -32,6 +32,7 @@ dependencies = [
"math_verify",
"ninja",
"fire",
"streamlit",
"flask",
"requests",
"tensorboard",
34 changes: 21 additions & 13 deletions trinity/common/config.py
@@ -101,7 +101,7 @@ class DatasetConfig:

name: str
storage_type: StorageType
algorithm_type: AlgorithmType
algorithm_type: AlgorithmType = AlgorithmType.PPO
path: Optional[str] = None
kwargs: Dict[str, Any] = field(default_factory=dict)
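
With `algorithm_type` now defaulting to `AlgorithmType.PPO` (and `_check_buffer` below overwriting it from `trainer.algorithm_type`), the example configs in this PR drop the field entirely; a minimal `train_dataset` entry now looks like:

```yaml
buffer:
  train_dataset:
    name: gsm8k_buffer
    storage_type: queue
    path: 'sqlite:///gsm8k.db'
```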

@@ -143,7 +143,7 @@ class ExplorerConfig:
# For async engine (vllm_async), it can be larger than `engine_num`, e.g. 16 * `engine_num`
runner_num: int = 1

# repeat each task for `repeat_times` times (for GPRO-like algrorithms)
# repeat each task for `repeat_times` times (for GRPO-like algorithms)
repeat_times: int = 1

# for rollout tokenization
@@ -177,7 +177,7 @@ class TrainerConfig:
trainer_config_path: str = ""
eval_interval: int = 100
enable_preview: bool = True # enable rollout preview in wandb
trainer_config: Any = None
trainer_config: Any = field(default_factory=dict)

# train algorithm
algorithm_type: AlgorithmType = AlgorithmType.PPO
@@ -266,21 +266,29 @@ def _check_buffer(self) -> None:
else:
if self.buffer.train_dataset is None:
raise ValueError("buffer.train_dataset is required when mode is not 'both'")
if self.buffer.train_dataset.algorithm_type != self.trainer.algorithm_type:
raise ValueError(
f"buffer.train_dataset.algorithm_type ({self.buffer.train_dataset.algorithm_type}) "
f"is not consistent with trainer.algorithm_type ({self.trainer.algorithm_type})"
)
self.buffer.train_dataset.algorithm_type = self.trainer.algorithm_type
if self.buffer.sft_warmup_dataset is not None:
self.buffer.sft_warmup_dataset.algorithm_type = AlgorithmType.SFT
self.buffer.read_batch_size = self.data.batch_size * self.explorer.repeat_times

def check_and_update(self) -> None:
"""Check and update the config."""
if self.trainer.trainer_type == "verl":
from trinity.common.verl_config import load_config

if not os.path.isfile(self.trainer.trainer_config_path):
raise ValueError(f"Invalid trainer config path: {self.trainer.trainer_config_path}")
self.trainer.trainer_config = load_config(self.trainer.trainer_config_path)
if self.trainer.trainer_config:
from trinity.common.verl_config import veRLConfig

trainer_config_schema = OmegaConf.structured(veRLConfig)
trainer_config = OmegaConf.merge(trainer_config_schema, self.trainer.trainer_config)
self.trainer.trainer_config = OmegaConf.to_object(trainer_config)
else:
if os.path.isfile(self.trainer.trainer_config_path):
from trinity.common.verl_config import load_config

self.trainer.trainer_config = load_config(self.trainer.trainer_config_path)
else:
raise ValueError(
f"Invalid trainer config path: {self.trainer.trainer_config_path}"
)
else:
raise ValueError(f"Invalid trainer type: {self.trainer_type}")

10 changes: 8 additions & 2 deletions trinity/common/verl_config.py
@@ -83,13 +83,13 @@ class Actor:
ppo_epochs: int = 1
shuffle: bool = False
ulysses_sequence_parallel_size: int = 1
checkpoint: Checkpoint = field(default_factory=Checkpoint)
optim: Optim = field(default_factory=Optim)
fsdp_config: FSDPConfig = field(default_factory=FSDPConfig)
alg_type: str = "ppo" # ppo / opmd / pairwise_opmd
tau: float = 0.001 # strength of regularization w.r.t. old / ref policy
opmd_baseline: str = "mean" # mean / logavgexp, applicable to opmd
use_uid: bool = False # True / False, applicable to pairwise_opmd
checkpoint: Checkpoint = field(default_factory=Checkpoint)


@dataclass
@@ -205,13 +205,17 @@ class CustomRewardFunction:
class KL_Ctrl:
type: str = "fixed"
kl_coef: float = 0.001
horizon: float = 10000
target_kl: float = 0.1


@dataclass
class Algorithm:
gamma: float = 1.0
lam: float = 1.0
adv_estimator: str = "gae"
norm_adv_by_std_in_grpo: bool = True
use_kl_in_reward: bool = False
kl_penalty: str = "kl"
kl_ctrl: KL_Ctrl = field(default_factory=KL_Ctrl)
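
The new `horizon` and `target_kl` fields only matter for an adaptive `kl_ctrl.type`; below is a minimal sketch of the standard adaptive controller they would parameterize, assuming Trinity-RFT follows the usual verl/PPO recipe:

```python
class AdaptiveKLController:
    """Nudges the KL coefficient so the observed KL tracks target_kl."""

    def __init__(self, init_kl_coef: float, target_kl: float, horizon: float):
        self.value = init_kl_coef
        self.target_kl = target_kl
        self.horizon = horizon

    def update(self, current_kl: float, n_steps: int) -> None:
        # Clipped proportional feedback keeps each adjustment gentle.
        proportional_error = max(min(current_kl / self.target_kl - 1, 0.2), -0.2)
        self.value *= 1.0 + proportional_error * n_steps / self.horizon
```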

@@ -300,7 +304,9 @@ def synchronize_config(self, config: Config) -> None:
self.actor_rollout_ref.rollout.temperature = config.explorer.temperature
self.actor_rollout_ref.rollout.n = config.explorer.repeat_times
batch_size_per_gpu = self.buffer.read_batch_size // world_size
self.actor_rollout_ref.actor.alg_type = config.trainer.algorithm_type.value
self.actor_rollout_ref.actor.alg_type = (
config.trainer.algorithm_type.value
) # TODO: refactor `alg_type`
print(f"using algorithm type: {self.actor_rollout_ref.actor.alg_type}")

if self.actor_rollout_ref.actor.alg_type == "dpo": # for DPO