3 changes: 1 addition & 2 deletions examples/dpo_humanlike/train_dpo.yaml
@@ -26,8 +26,7 @@ actor_rollout_ref:
min_lr_ratio: 0.1 # only useful for warmup with cosine
warmup_style: cosine # select from constant/cosine
total_training_steps: 783 #
beta1: 0.9
beta2: 0.95
betas: [0.9, 0.95]
fsdp_config:
wrap_policy:
# transformer_layer_cls_to_wrap: None
7 changes: 3 additions & 4 deletions examples/opmd_gsm8k/train_opmd_gsm8k.yaml
@@ -15,8 +15,8 @@
# entropy_coeff: default to 0.0 for now
#
# optimizer:
# beta1, beta2: 0.0, 0.95 # smaller than default values (0.9, 0.999), as a remedy for abrupt distribution shift
# lr: set smaller to account for beta1 = 0.0
# betas: [0.0, 0.95] # smaller than default values (0.9, 0.999), as a remedy for abrupt distribution shift
# lr: set smaller to account for betas[0] = 0.0
#
# misc:
# adv_estimator: grpo # merely to disable critic model, doesn't affect adv compute when algorithm_type is opmd
@@ -50,8 +50,7 @@ actor_rollout_ref:
# min_lr_ratio: null # only useful for warmup with cosine
warmup_style: constant # select from constant/cosine
total_training_steps: -1 # must be overridden by program
beta1: 0.0 # set to smaller value for scenarios with abrupt distribution shift (e.g., large sync_interval)
beta2: 0.95 # set to smaller value for scenarios with abrupt distribution shift (e.g., large sync_interval)
betas: [0.0, 0.95] # set to smaller value for scenarios with abrupt distribution shift (e.g., large sync_interval)
fsdp_config:
wrap_policy:
# transformer_layer_cls_to_wrap: None
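Both YAML files above replace the separate beta1/beta2 optimizer fields with a single betas list, matching the (beta1, beta2) pair that PyTorch's Adam-family optimizers accept. A minimal sketch of how such a config value is typically consumed, using a hypothetical build_optimizer helper rather than the actual verl/Trinity-RFT code path:

# Illustrative only: maps a YAML `betas: [0.0, 0.95]` entry onto AdamW's betas argument.
import torch
from torch import nn

def build_optimizer(module: nn.Module, lr: float, betas: list[float]) -> torch.optim.AdamW:
    # AdamW expects betas as a 2-tuple, so the config list converts directly.
    return torch.optim.AdamW(module.parameters(), lr=lr, betas=tuple(betas))

optimizer = build_optimizer(nn.Linear(8, 8), lr=1e-6, betas=[0.0, 0.95])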
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -21,7 +21,7 @@ classifiers = [
]
requires-python = ">=3.10"
dependencies = [
"verl==0.3.0.post1",
"verl==0.4.0",
"ray[default]>=2.45.0",
"vllm==0.8.5.post1",
"tensordict==0.6.2",
13 changes: 11 additions & 2 deletions trinity/common/verl_config.py
@@ -33,8 +33,7 @@ class Optim:
min_lr_ratio: Optional[float] = 0.0
warmup_style: str = "constant"
total_training_steps: int = -1
beta1: float = 0.9
beta2: float = 0.999
betas: List[float] = field(default_factory=lambda: [0.9, 0.999])


@dataclass
@@ -82,6 +81,7 @@ class Actor:
tau: float = 0.001 # strength of regularization w.r.t. old / ref policy
opmd_baseline: str = "mean" # mean / logavgexp, applicable to opmd
use_uid: bool = False # True / False, applicable to pairwise_opmd
loss_agg_mode: str = "token-mean" # do not set


@dataclass
@@ -99,12 +99,20 @@ class _ValKwargs:
do_sample: bool = False


@dataclass
class _MultiTurn:
enable: bool = False


@dataclass
class Rollout:
# do not set
val_kwargs: _ValKwargs = field(default_factory=_ValKwargs)
multi_turn: _MultiTurn = field(default_factory=_MultiTurn)
temperature: float = 1.0
n: int = 1 # > 1 for grpo
log_prob_micro_batch_size: Optional[int] = None
log_prob_micro_batch_size_per_gpu: int = 1


@dataclass
@@ -148,6 +156,7 @@ class Critic:
cliprange_value: float = 0.0
checkpoint: Checkpoint = field(default_factory=Checkpoint)
rollout_n: int = 1
loss_agg_mode: str = "token-mean"


@dataclass
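The new fields in verl_config.py track verl 0.4.0's config schema: Optim now carries a betas list, Rollout gains a multi_turn sub-config plus per-GPU log-prob micro-batch settings, and both Actor and Critic expose loss_agg_mode. A small sanity-check sketch of the defaults shown above, assuming the dataclasses are imported from the file being modified:

# Illustrative only: default values as defined in trinity/common/verl_config.py above.
from trinity.common.verl_config import Optim, Rollout

optim = Optim()
assert optim.betas == [0.9, 0.999]          # replaces the old beta1/beta2 pair

rollout = Rollout()
assert rollout.multi_turn.enable is False   # new _MultiTurn sub-config
assert rollout.val_kwargs.do_sample is False
assert rollout.log_prob_micro_batch_size_per_gpu == 1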
218 changes: 42 additions & 176 deletions trinity/trainer/verl/dp_actor.py
@@ -1,4 +1,6 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023-2024 SGLang Team
# Copyright 2025 ModelBest Inc. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,49 +14,42 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Modified from dp_actor.py
Single Process Actor.
Modified from https://github.com/volcengine/verl/blob/0758489422e8d41a89e6c36d4c477714520f0dcc/verl/workers/actor/dp_actor.py
"""

import itertools
from typing import Tuple
import logging
import os

import torch
import verl.utils.torch_functional as verl_F
from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from verl import DataProto
from verl.utils.debug import GPUMemoryLogger
from verl.utils.device import get_torch_device
from verl.utils.py_functional import append_to_dict
from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches
from verl.utils.torch_functional import logprobs_from_logits
from verl.utils.ulysses import gather_outpus_and_unpad, ulysses_pad_and_slice_inputs
from verl.workers.actor import BasePPOActor
from verl.workers.actor.dp_actor import DataParallelPPOActor as DPActor

from trinity.algorithm import ENTROPY_LOSS_FN, KL_FN, POLICY_LOSS_FN
from trinity.algorithm.entropy_loss_fn.entropy_loss_fn import DummyEntropyLossFn
from trinity.algorithm.kl_fn.kl_fn import DummyKLFn
from trinity.algorithm.utils import prefix_metrics
from trinity.common.config import AlgorithmConfig

__all__ = ["DataParallelPPOActor"]

logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))

class DataParallelPPOActor(BasePPOActor):

class DataParallelPPOActor(DPActor):
def __init__(
self,
config,
actor_module: nn.Module,
actor_optimizer: torch.optim.Optimizer = None,
self, config, actor_module: nn.Module, actor_optimizer: torch.optim.Optimizer = None
):
"""When optimizer is None, it is Reference Policy"""
super().__init__(config)
self.actor_module = actor_module
self.actor_optimizer = actor_optimizer
self.use_remove_padding = self.config.get("use_remove_padding", False)
print(f"Actor use_remove_padding={self.use_remove_padding}")
self.ulysses_sequence_parallel_size = self.config.ulysses_sequence_parallel_size
self.use_ulysses_sp = self.ulysses_sequence_parallel_size > 1

self.compute_entropy_from_logits = torch.compile(verl_F.entropy_from_logits, dynamic=True)
super().__init__(config, actor_module, actor_optimizer)

self.policy_loss_fn = None
self.kl_loss_fn = None
self.entropy_loss_fn = None
@@ -68,150 +63,8 @@ def set_algorithm(self, algorithm_config: AlgorithmConfig):
**algorithm_config.entropy_loss_fn_args
)

def _forward_micro_batch(self, micro_batch, temperature) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Returns:
entropy: # (bs, response_len)
log_probs: # (bs, response_len)
"""
response_length = micro_batch["responses"].size(-1)
multi_modal_inputs = {}
if "multi_modal_inputs" in micro_batch:
for key in micro_batch["multi_modal_inputs"][0].keys():
multi_modal_inputs[key] = torch.cat(
[inputs[key] for inputs in micro_batch["multi_modal_inputs"]], dim=0
)

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
input_ids = micro_batch["input_ids"]
batch_size, seqlen = input_ids.shape
attention_mask = micro_batch["attention_mask"]
position_ids = micro_batch["position_ids"]
if position_ids.dim() == 3: # qwen2vl mrope
position_ids = position_ids.transpose(0, 1) # (bsz, 3, seqlen) -> (3, bsz, seqlen)

if self.use_remove_padding:
input_ids_rmpad, indices, *_ = unpad_input(
input_ids.unsqueeze(-1), attention_mask
) # input_ids_rmpad (total_nnz, ...)
input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz)

# unpad the position_ids to align the rotary
if position_ids.dim() == 3:
position_ids_rmpad = (
index_first_axis(
rearrange(position_ids, "c b s ... -> (b s) c ..."), indices
)
.transpose(0, 1)
.unsqueeze(1)
) # (3, bsz, seqlen) -> (3, 1, bsz * seqlen)
else:
position_ids_rmpad = index_first_axis(
rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices
).transpose(0, 1)

# for compute the log_prob
input_ids_rmpad_rolled = torch.roll(
input_ids_rmpad, shifts=-1, dims=1
) # (1, total_nnz)

# pad and slice the inputs if sp > 1
if self.use_ulysses_sp:
input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(
input_ids_rmpad,
position_ids_rmpad,
sp_size=self.ulysses_sequence_parallel_size,
)
input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs(
input_ids_rmpad_rolled, None, self.ulysses_sequence_parallel_size
)

input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze(
0
) # ((total_nnz / sp) + pad)

# only pass input_ids and position_ids to enable flash_attn_varlen
output = self.actor_module(
input_ids=input_ids_rmpad,
attention_mask=None,
position_ids=position_ids_rmpad,
**multi_modal_inputs,
use_cache=False,
) # prevent model thinks we are generating
logits_rmpad = output.logits.squeeze(0) # (total_nnz, vocab_size)

logits_rmpad.div_(temperature)

# compute entropy
entropy_rmpad = self.compute_entropy_from_logits(
logits_rmpad
) # ((total_nnz / sp) + pad)

# if use_sp: ((total_nnz / sp) + pad) ; if not use_sp: (batch, seqlen)
log_probs = logprobs_from_logits(logits=logits_rmpad, labels=input_ids_rmpad_rolled)

# gather log_prob if sp > 1
if self.use_ulysses_sp:
# gather and unpad for the ulysses sp
log_probs = gather_outpus_and_unpad(
log_probs, gather_dim=0, unpad_dim=0, padding_size=pad_size
)
entropy_rmpad = gather_outpus_and_unpad(
entropy_rmpad, gather_dim=0, unpad_dim=0, padding_size=pad_size
)
# pad back to (bsz, seqlen)
full_entropy = pad_input(
hidden_states=entropy_rmpad.unsqueeze(-1),
indices=indices,
batch=batch_size,
seqlen=seqlen,
)
full_log_probs = pad_input(
hidden_states=log_probs.unsqueeze(-1),
indices=indices,
batch=batch_size,
seqlen=seqlen,
)

# only return response part:
entropy = full_entropy.squeeze(-1)[
:, -response_length - 1 : -1
] # (bsz, response_length)
log_probs = full_log_probs.squeeze(-1)[
:, -response_length - 1 : -1
] # (bsz, response_length)

else: # not using rmpad and no ulysses sp
output = self.actor_module(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
**multi_modal_inputs,
use_cache=False,
) # prevent model thinks we are generating
logits = output.logits
logits.div_(temperature)
logits = logits[
:, -response_length - 1 : -1, :
] # (bsz, response_length, vocab_size)
log_probs = logprobs_from_logits(logits, micro_batch["responses"])
entropy = verl_F.entropy_from_logits(logits) # (bsz, response_length)

return entropy, log_probs

def _optimizer_step(self):
assert self.config.grad_clip is not None

if isinstance(self.actor_module, FSDP):
grad_norm = self.actor_module.clip_grad_norm_(max_norm=self.config.grad_clip)
else:
grad_norm = torch.nn.utils.clip_grad_norm_(
self.actor_module.parameters(), max_norm=self.config.grad_clip
)
self.actor_optimizer.step()
return grad_norm

def compute_log_prob(self, data: DataProto) -> torch.Tensor:
@GPUMemoryLogger(role="dp actor", logger=logger)
def compute_log_prob(self, data: DataProto, calculate_entropy=False) -> torch.Tensor:
"""Compute the log probability of the responses given input_ids, attention_mask and position_ids

Args:
@@ -235,7 +88,7 @@ def compute_log_prob(self, data: DataProto) -> torch.Tensor:
micro_batch_size = data.meta_info["micro_batch_size"]
temperature = data.meta_info[
"temperature"
] # temperature must be in the data.meta_info to avoid slient error
] # temperature must be in the data.meta_info to avoid silent error
use_dynamic_bsz = data.meta_info["use_dynamic_bsz"]

select_keys = ["responses", "input_ids", "attention_mask", "position_ids"]
@@ -258,30 +111,40 @@ def compute_log_prob(self, data: DataProto) -> torch.Tensor:
micro_batches = batch.split(micro_batch_size)

log_probs_lst = []
entropy_lst = []
for micro_batch in micro_batches:
if isinstance(micro_batch, DataProto):
micro_batch = {**micro_batch.batch, **micro_batch.non_tensor_batch}

with torch.no_grad():
_, log_probs = self._forward_micro_batch(micro_batch, temperature=temperature)
entropy, log_probs = self._forward_micro_batch(
micro_batch, temperature=temperature, calculate_entropy=calculate_entropy
)
log_probs_lst.append(log_probs)
log_probs = torch.concat(log_probs_lst, dim=0)
if calculate_entropy:
entropy_lst.append(entropy)

log_probs = torch.concat(log_probs_lst, dim=0)
entropys = None
if calculate_entropy:
entropys = torch.concat(entropy_lst, dim=0)
if use_dynamic_bsz:
indices = list(itertools.chain.from_iterable(indices))
assert len(indices) == log_probs.size(0), f"{len(indices)} vs. {log_probs.size()}"
revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long)
log_probs = log_probs[revert_indices]
if calculate_entropy:
entropys = entropys[revert_indices] # type: ignore

return log_probs
return log_probs, entropys

def update_policy(self, data: DataProto): # noqa: C901
@GPUMemoryLogger(role="dp actor", logger=logger)
def update_policy(self, data: DataProto):
# make sure we are in training mode
self.actor_module.train()

temperature = data.meta_info[
"temperature"
] # temperature must be in the data.meta_info to avoid slient error
] # temperature must be in the data.meta_info to avoid silent error
select_keys = [
"input_ids",
"position_ids",
@@ -356,12 +219,12 @@ def update_policy(self, data: DataProto): # noqa: C901
# Support all hardwares
if isinstance(data, DataProto):
data = {
**data.batch.to(torch.cuda.current_device()),
**data.batch.to(get_torch_device().current_device()),
**data.non_tensor_batch,
}
else:
data = data.to(
torch.cuda.current_device()
get_torch_device().current_device()
) # actor device is cpu when using offload
responses = data["responses"]
response_length = responses.size(1)
@@ -370,8 +233,11 @@ def update_policy(self, data: DataProto): # noqa: C901
assert response_mask.shape == attention_mask[:, -response_length:].shape

# all return: (bsz, response_length)
calculate_entropy = self.entropy_loss_fn != DummyEntropyLossFn
entropy, log_prob = self._forward_micro_batch(
micro_batch=data, temperature=temperature
micro_batch=data,
temperature=temperature,
calculate_entropy=calculate_entropy,
)

kwargs = {
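With this change, DataParallelPPOActor follows the upstream verl 0.4.0 interface: compute_log_prob returns a (log_probs, entropys) pair, where entropys is None unless calculate_entropy=True, and update_policy derives calculate_entropy from the configured entropy loss function. A hedged usage sketch; the actor instance, the DataProto contents, and the meta_info keys are taken from the diff above rather than verified end to end:

# Illustrative only: unpacking the new compute_log_prob return value.
# `actor` is a constructed DataParallelPPOActor and `data` a DataProto whose
# meta_info already holds "micro_batch_size", "temperature" and "use_dynamic_bsz".
log_probs, entropys = actor.compute_log_prob(data, calculate_entropy=False)
assert entropys is None                      # entropy only computed on request

log_probs, entropys = actor.compute_log_prob(data, calculate_entropy=True)
mean_entropy = entropys.mean()               # (bsz, response_length) reduced for logging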