2 changes: 1 addition & 1 deletion pyproject.toml
@@ -21,7 +21,7 @@ classifiers = [
]
requires-python = ">=3.10,<3.13"
dependencies = [
"verl==0.5.0",
"verl==0.7.0",
"ray[default]>=2.50.0",
"tensordict",
"wandb",
15 changes: 10 additions & 5 deletions trinity/common/config.py
@@ -116,6 +116,8 @@ class LoRAConfig:
lora_alpha: int = 32
lora_dtype: str = "auto"
target_modules: str = "all-linear"
exclude_modules: Optional[str] = None
is_dummy: bool = False # DO NOT SET, automatically set


@Experimental
@@ -1356,16 +1358,19 @@ def _check_explorer(self) -> None:
self.explorer.rollout_model.enable_lora = True
if len(self.model.lora_configs) > 1:
raise ValueError("Only one lora adapter is supported for now.")
if self.model.lora_configs[0].path is None:
lora_config = self.model.lora_configs[0]
if lora_config.path is None:
logger.info("Creating dummy lora, since no lora_path is provided.")
lora_path = create_dummy_lora(
model_path=self.model.model_path,
checkpoint_job_dir=self.checkpoint_job_dir,
lora_rank=self.model.lora_configs[0].lora_rank,
lora_alpha=self.model.lora_configs[0].lora_alpha,
target_modules=self.model.lora_configs[0].target_modules,
lora_rank=lora_config.lora_rank,
lora_alpha=lora_config.lora_alpha,
target_modules=lora_config.target_modules,
exclude_modules=lora_config.exclude_modules,
)
self.model.lora_configs[0].path = lora_path
lora_config.path = lora_path
lora_config.is_dummy = True
self.explorer.rollout_model.lora_modules = [
{
"lora_int_id": i + 1,
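For reference, a minimal sketch of how the new LoRAConfig fields flow through the dummy-adapter path above. The stand-alone ensure_lora_path helper and the stubbed create_dummy_lora are illustrative only; defaults and the model path are placeholders, and only the field names mirror trinity/common/config.py.

from dataclasses import dataclass
from typing import Optional


@dataclass
class LoRAConfig:
    # Simplified mirror of trinity.common.config.LoRAConfig; defaults are illustrative.
    lora_rank: int = 32
    lora_alpha: int = 32
    lora_dtype: str = "auto"
    target_modules: str = "all-linear"
    exclude_modules: Optional[str] = None
    path: Optional[str] = None
    is_dummy: bool = False  # set automatically when a dummy adapter is created


def create_dummy_lora(**kwargs) -> str:
    # Stand-in for Trinity's helper; here it only fakes an adapter path.
    return "/tmp/dummy_lora_adapter"


def ensure_lora_path(lora_config: LoRAConfig, model_path: str, checkpoint_job_dir: str) -> LoRAConfig:
    """Create a placeholder adapter and mark the config as dummy when no path is given."""
    if lora_config.path is None:
        lora_config.path = create_dummy_lora(
            model_path=model_path,
            checkpoint_job_dir=checkpoint_job_dir,
            lora_rank=lora_config.lora_rank,
            lora_alpha=lora_config.lora_alpha,
            target_modules=lora_config.target_modules,
            exclude_modules=lora_config.exclude_modules,
        )
        lora_config.is_dummy = True  # later lets synchronize_config skip lora_adapter_path
    return lora_config


cfg = ensure_lora_path(LoRAConfig(), model_path="path/to/base_model", checkpoint_job_dir="/tmp/ckpt")
print(cfg.path, cfg.is_dummy)  # -> /tmp/dummy_lora_adapter True
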
62 changes: 55 additions & 7 deletions trinity/common/verl_config.py
@@ -4,7 +4,9 @@
from typing import Any, Dict, List, Optional

from omegaconf import OmegaConf
from verl.workers.config import PolicyLossConfig, RouterReplayConfig

from trinity.algorithm import ALGORITHM_TYPE
from trinity.common.config import Config, SynchronizerConfig, set_if_none
from trinity.common.constants import EXPLORER_NAME
from trinity.utils.log import get_logger
@@ -41,6 +43,8 @@ class ActorModel:
lora_rank: int = 0 # The rank of the LoRA model, default to 0. If lora_rank > 0, LoRA module is enabled in trainer
lora_alpha: int = 32
target_modules: Optional[str] = "all-linear"
exclude_modules: Optional[str] = None
lora_adapter_path: Optional[str] = None

# rope configs
rope_scaling: Optional[dict] = None
@@ -51,14 +55,15 @@
class Optim:
# For actor, most fields are set in algorithm.optimizer
# For critic, you can set trainer_config.critic.optim
optimizer: str = "AdamW"
optimizer_impl: str = "torch.optim"
lr: float = 1e-6
lr_warmup_steps: int = -1
lr_warmup_steps_ratio: float = 0.0
min_lr_ratio: Optional[float] = 0.0
warmup_style: str = "constant"
total_training_steps: int = -1 # ! DO NOT SET, use trainer.total_steps
betas: List[float] = field(default_factory=lambda: [0.9, 0.999])
optimizer: str = "adam"
clip_grad: float = 1.0
lr_warmup_init: float = 0.0
lr_decay_steps: Optional[int] = None
@@ -69,6 +74,7 @@ class Optim:
lr_wsd_decay_style: str = "exponential"
lr_wsd_decay_steps: Optional[int] = None
use_checkpoint_opt_param_scheduler: bool = False
override_optimizer_config: Optional[dict] = None


@dataclass
@@ -78,6 +84,7 @@ class WrapPolicy:

@dataclass
class FSDPConfig:
_target_: str = "verl.workers.config.FSDPEngineConfig" # DO NOT SET
param_offload: bool = False
optimizer_offload: bool = False
offload_policy: bool = False
@@ -92,7 +99,7 @@ class FSDPConfig:
class Checkpoint:
load_contents: List[str] = field(default_factory=lambda: ["model", "optimizer", "extra"])
save_contents: List[str] = field(default_factory=lambda: ["model", "optimizer", "extra"])
async_save: bool = False # do not set, async save has bug in verl megatron training
async_save: bool = False # TODO: testing async save


@dataclass
@@ -124,6 +131,8 @@ class MegatronConfig:
default_factory=OverrideTransformerConfig
)
use_mbridge: bool = False
dtype: str = "bfloat16"
use_remove_padding: bool = True


@dataclass
@@ -157,6 +166,9 @@ class Actor:
profile: ProfileConfig = field(default_factory=ProfileConfig)
data_loader_seed: Optional[int] = None
load_weight: bool = True
policy_loss: PolicyLossConfig = field(default_factory=PolicyLossConfig)
profiler: dict = field(default_factory=dict)
router_replay: RouterReplayConfig = field(default_factory=RouterReplayConfig)
# do not set
loss_agg_mode: str = "token-mean"
clip_ratio: float = 0.2
@@ -182,6 +194,8 @@ class Ref:
megatron: MegatronConfig = field(default_factory=MegatronConfig)
profile: ProfileConfig = field(default_factory=ProfileConfig)
load_weight: bool = True
profiler: dict = field(default_factory=dict)
router_replay: RouterReplayConfig = field(default_factory=RouterReplayConfig)


@dataclass
@@ -214,6 +228,7 @@ class ActorRolloutRef:
actor: Actor = field(default_factory=Actor)
ref: Ref = field(default_factory=Ref)
rollout: Rollout = field(default_factory=Rollout)
nccl_timeout: float = 600 # ! DO NOT SET, it will be set by `config.synchronizer.sync_timeout`
synchronizer: Optional[SynchronizerConfig] = None
explorer_name: str = EXPLORER_NAME

@@ -232,6 +247,7 @@ class CriticModel:

@dataclass
class Critic:
enable: bool = False
strategy: Optional[str] = None
optim: Optim = field(default_factory=Optim)
model: CriticModel = field(default_factory=CriticModel)
@@ -255,7 +271,9 @@ class Critic:
profile: ProfileConfig = field(default_factory=ProfileConfig)
data_loader_seed: Optional[int] = None
load_weight: bool = True
nccl_timeout: float = 600 # ! DO NOT SET, it will be set by `config.synchronizer.sync_timeout`
ray_namespace: str = "" # automatically generated
profiler: dict = field(default_factory=dict)


@dataclass
@@ -278,6 +296,7 @@ class RewardModel:
use_dynamic_bsz: bool = False
forward_max_token_len_per_gpu: int = 0
reward_manager: str = "naive"
use_reward_loop: bool = True


@dataclass
@@ -294,8 +313,24 @@ class KL_Ctrl:
target_kl: float = 0.1


@dataclass
class RolloutCorrection:
rollout_is: Optional[str] = None
rollout_is_threshold: float = 2.0
rollout_rs: Optional[str] = None
rollout_rs_threshold: Optional[float] = None
rollout_rs_threshold_lower: Optional[float] = None
rollout_token_veto_threshold: Optional[float] = None
# Because rollout and training in Trinity run separately,
# rollout_is_batch_normalize defaults to True
bypass_mode: bool = True
loss_type: str = "ppo_clip"
rollout_is_batch_normalize: bool = False


@dataclass
class Algorithm:
rollout_correction: RolloutCorrection = field(default_factory=RolloutCorrection)
# ! DO NOT SET gamma or lam below; they are kept here merely for compatibility with verl,
# and their values will be overwritten by those in AlgorithmConfig.advantage_fn_args
# if they are really needed (e.g., for GAE advantage/returns computation)
@@ -349,6 +384,7 @@ class veRLConfig:
custom_reward_function: CustomRewardFunction = field(default_factory=CustomRewardFunction)
algorithm: Algorithm = field(default_factory=Algorithm)
trainer: Trainer = field(default_factory=Trainer)
global_profiler: dict = field(default_factory=dict)
synchronizer: Optional[SynchronizerConfig] = None
enable_preview: bool = True

@@ -423,8 +459,12 @@ def synchronize_config(self, config: Config) -> None: # noqa: C901
) # kept to pass RayPPOTrainer._validate_config

self.synchronizer = config.synchronizer
self.actor_rollout_ref.nccl_timeout = config.synchronizer.sync_timeout
self.actor_rollout_ref.synchronizer = config.synchronizer
self.actor_rollout_ref.explorer_name = config.explorer.name
algorithm = ALGORITHM_TYPE.get(config.algorithm.algorithm_type)
self.critic.enable = algorithm.use_critic
self.critic.nccl_timeout = config.synchronizer.sync_timeout
self.critic.ray_namespace = config.synchronizer.ray_namespace

# Actor / Rollout Config
@@ -539,11 +579,12 @@ def synchronize_config(self, config: Config) -> None: # noqa: C901

# LoRA related config
if config.model.lora_configs is not None:
self.actor_rollout_ref.model.lora_rank = config.model.lora_configs[0].lora_rank
self.actor_rollout_ref.model.lora_alpha = config.model.lora_configs[0].lora_alpha
self.actor_rollout_ref.model.target_modules = config.model.lora_configs[
0
].target_modules
lora_config = config.model.lora_configs[0]
actor_model_config = self.actor_rollout_ref.model
for attr in ["lora_rank", "lora_alpha", "target_modules", "exclude_modules"]:
setattr(actor_model_config, attr, getattr(lora_config, attr))
if not lora_config.is_dummy:
actor_model_config.lora_adapter_path = lora_config.path
if self.actor_rollout_ref.actor.strategy not in ["fsdp", "fsdp2"]:
logger.warning(
f"Lora is only supported for fsdp and fsdp2, but got {self.actor_rollout_ref.actor.strategy} instead, changed to fsdp."
@@ -559,6 +600,13 @@ def synchronize_config(self, config: Config) -> None: # noqa: C901
for field_name in config.algorithm.optimizer.__dataclass_fields__:
field_value = getattr(config.algorithm.optimizer, field_name)
if field_name == "optimizer_type":
if config.trainer.trainer_strategy.startswith("fsdp"):
optim_map = {
"adam": "AdamW",
"adamw": "AdamW",
"sgd": "SGD",
}
field_value = optim_map.get(field_value, field_value)
setattr(self.actor_rollout_ref.actor.optim, "optimizer", field_value)
elif hasattr(self.actor_rollout_ref.actor.optim, field_name):
setattr(self.actor_rollout_ref.actor.optim, field_name, field_value)
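For reference, the optimizer-name translation added to synchronize_config, extracted into a stand-alone sketch. It assumes the FSDP path in verl 0.7 expects torch.optim class names ("AdamW", "SGD") while Trinity's algorithm config uses lower-case names; OPTIM_NAME_MAP and resolve_optimizer_name are hypothetical names used here for illustration.

# Mapping mirrored from the optim_map added in synchronize_config above.
OPTIM_NAME_MAP = {
    "adam": "AdamW",
    "adamw": "AdamW",
    "sgd": "SGD",
}


def resolve_optimizer_name(optimizer_type: str, trainer_strategy: str) -> str:
    """Translate Trinity's optimizer_type into the name expected when training with FSDP."""
    if trainer_strategy.startswith("fsdp"):
        return OPTIM_NAME_MAP.get(optimizer_type, optimizer_type)
    return optimizer_type  # non-FSDP strategies (e.g. megatron) keep the original name


print(resolve_optimizer_name("adam", "fsdp2"))     # -> AdamW
print(resolve_optimizer_name("adam", "megatron"))  # -> adam
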
45 changes: 36 additions & 9 deletions trinity/trainer/verl/dp_actor.py
@@ -15,7 +15,7 @@
# limitations under the License.
"""
Single Process Actor.
Modified from https://github.com/volcengine/verl/blob/v0.5.0/verl/workers/actor/dp_actor.py
Modified from https://github.com/volcengine/verl/blob/v0.7.0/verl/workers/actor/dp_actor.py
"""

import logging
@@ -67,9 +67,8 @@ def update_policy(self, data: DataProto): # noqa: C901
# make sure we are in training mode
self.actor_module.train()

temperature = data.meta_info[
"temperature"
] # temperature must be in the data.meta_info to avoid silent error
# temperature must be in the data.meta_info to avoid silent error
temperature = data.meta_info["temperature"]
select_keys = [
"input_ids",
"position_ids",
@@ -80,13 +79,17 @@
select_keys.extend(self.policy_loss_fn.select_keys)
if not isinstance(self.kl_loss_fn, DummyKLFn):
select_keys.append("ref_log_prob")
# rollout_is_weights will be used in policy loss
# rollout_log_probs is equal to old_log_prob in Trinity
select_keys = list(set(select_keys))

has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys()
non_tensor_select_keys = ["multi_modal_inputs"] if has_multi_modal_inputs else []

data = data.select(batch_keys=select_keys, non_tensor_batch_keys=non_tensor_select_keys)

# Split to make minibatch iterator for updating the actor
# See PPO paper for details. https://arxiv.org/abs/1707.06347
mini_batches = data.split(self.config.ppo_mini_batch_size)

# EXPERIMENTAL: apply loss scale fix
@@ -119,12 +122,11 @@ def update_policy(self, data: DataProto): # noqa: C901
self.actor_optimizer.zero_grad()

for micro_batch in micro_batches:
micro_batch = micro_batch.to(get_device_id())
micro_batch_metrics = {}
model_inputs = {
**micro_batch.batch.to(get_device_id()),
**micro_batch.non_tensor_batch,
}
model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}
response_mask = model_inputs["response_mask"]
loss_mode = self.config.policy_loss.get("loss_mode", "vanilla")

# all return: (bsz, response_length)
calculate_entropy = self.entropy_loss_fn != DummyEntropyLossFn
@@ -141,6 +143,23 @@ def update_policy(self, data: DataProto): # noqa: C901
src_metrics=pg_loss_metrics, prefix="actor", dst_metrics=micro_batch_metrics
)

# TODO: to be checked
# Skip if using bypass_mode loss (metrics already computed in pg_metrics)
rollout_log_prob = model_inputs.get("rollout_log_probs", None)
if loss_mode != "bypass_mode" and rollout_log_prob is not None:
# Compute metrics using CURRENT policy π_θ vs π_rollout
# Tracks evolving off-policy gap as π_θ updates during mini-batch training
from verl.trainer.ppo.rollout_corr_helper import (
compute_rollout_corr_metrics_from_logprobs,
)

rollout_corr_metrics = compute_rollout_corr_metrics_from_logprobs(
log_prob=log_prob,
rollout_log_prob=rollout_log_prob,
response_mask=response_mask,
)
micro_batch_metrics.update(rollout_corr_metrics)

# compute entropy loss from entropy
entropy_loss, entropy_loss_metrics = self.entropy_loss_fn( # type: ignore
entropy=entropy,
@@ -185,7 +204,15 @@

loss = policy_loss * loss_scale
micro_batch_metrics["actor/final_loss"] = loss.detach().item()
loss.backward()
if "actor/kl_loss" in micro_batch_metrics:
micro_batch_metrics["actor/kl_loss"] *= loss_scale
if "actor/pg_loss" in micro_batch_metrics:
micro_batch_metrics["actor/pg_loss"] *= loss_scale

if self.scaler is not None:
self.scaler.scale(loss).backward()
else:
loss.backward()

append_to_dict(metrics, micro_batch_metrics)

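For reference, a minimal sketch of the scaler-aware backward pass added to update_policy, assuming self.scaler is an optional AMP GradScaler held by the actor; the free function and the dummy loss below are illustrative only.

from typing import Optional

import torch
from torch.cuda.amp import GradScaler


def backward_with_optional_scaler(loss: torch.Tensor, scaler: Optional[GradScaler]) -> None:
    """Scale the loss before backward when an AMP GradScaler is configured, otherwise backprop directly."""
    if scaler is not None:
        scaler.scale(loss).backward()
    else:
        loss.backward()


# Usage with a dummy scalar loss and no scaler (CPU-friendly):
params = torch.randn(4, requires_grad=True)
loss = (params ** 2).mean()
backward_with_optional_scaler(loss, scaler=None)
print(params.grad is not None)  # -> True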