
Commit aeabfe5

merge verl 0.4.0 (#79)
1 parent 0e56607 commit aeabfe5

7 files changed: +499 -710 lines changed


examples/dpo_humanlike/train_dpo.yaml

Lines changed: 1 addition & 2 deletions
@@ -26,8 +26,7 @@ actor_rollout_ref:
 min_lr_ratio: 0.1 # only useful for warmup with cosine
 warmup_style: cosine # select from constant/cosine
 total_training_steps: 783 #
-beta1: 0.9
-beta2: 0.95
+betas: [0.9, 0.95]
 fsdp_config:
 wrap_policy:
 # transformer_layer_cls_to_wrap: None
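
The diff above collapses the separate `beta1`/`beta2` keys into one `betas` list, matching how verl 0.4.0 groups the Adam hyperparameters. As a minimal illustrative sketch (not code from this repository or from verl), such a list maps directly onto the `betas` tuple that `torch.optim.AdamW` expects:

```python
# Illustrative wiring only: a config-style dict with a `betas` list handed to AdamW.
import torch
from torch import nn

optim_cfg = {"lr": 1e-6, "betas": [0.9, 0.95]}  # mirrors the YAML values above

model = nn.Linear(8, 8)  # placeholder module standing in for the actor
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=optim_cfg["lr"],
    betas=tuple(optim_cfg["betas"]),  # one list replaces separate beta1/beta2 keys
)
print(optimizer.defaults["betas"])  # (0.9, 0.95)
```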

examples/opmd_gsm8k/train_opmd_gsm8k.yaml

Lines changed: 3 additions & 4 deletions
@@ -15,8 +15,8 @@
 # entropy_coeff: default to 0.0 for now
 #
 # optimizer:
-# beta1, beta2: 0.0, 0.95 # smaller than default values (0.9, 0.999), as a remedy for abrupt distribution shift
-# lr: set smaller to account for beta1 = 0.0
+# betas: [0.0, 0.95] # smaller than default values (0.9, 0.999), as a remedy for abrupt distribution shift
+# lr: set smaller to account for betas[0] = 0.0
 #
 # misc:
 # adv_estimator: grpo # merely to disable critic model, doesn't affect adv compute when algorithm_type is opmd
@@ -50,8 +50,7 @@ actor_rollout_ref:
 # min_lr_ratio: null # only useful for warmup with cosine
 warmup_style: constant # select from constant/cosine
 total_training_steps: -1 # must be override by program
-beta1: 0.0 # set to smaller value for scenarios with abrupt distribution shift (e.g., large sync_interval)
-beta2: 0.95 # set to smaller value for scenarios with abrupt distribution shift (e.g., large sync_interval)
+betas: [0.0, 0.95] # set to smaller value for scenarios with abrupt distribution shift (e.g., large sync_interval)
 fsdp_config:
 wrap_policy:
 # transformer_layer_cls_to_wrap: None
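
Setting `betas[0]` to 0.0 switches off Adam's running average of past gradients, so after an abrupt distribution shift (for example, a large sync_interval) the update direction follows only freshly computed gradients, which is the rationale stated in the comments above. A small plain-Python illustration of the first-moment update, not repository code:

```python
# Adam first moment: m_t = beta1 * m_{t-1} + (1 - beta1) * g_t.
# With beta1 = 0.0 the stale momentum m_{t-1} drops out entirely.
def first_moment(prev_m: float, grad: float, beta1: float) -> float:
    return beta1 * prev_m + (1.0 - beta1) * grad

stale_m, new_grad = 5.0, -1.0  # momentum built before the shift vs. gradient after it
print(first_moment(stale_m, new_grad, beta1=0.9))  # 4.4  -> still dominated by history
print(first_moment(stale_m, new_grad, beta1=0.0))  # -1.0 -> follows the new gradient only
```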

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@ classifiers = [
 ]
 requires-python = ">=3.10"
 dependencies = [
-    "verl==0.3.0.post1",
+    "verl==0.4.0",
     "ray[default]>=2.45.0",
     "vllm==0.8.5.post1",
     "tensordict==0.6.2",

trinity/common/verl_config.py

Lines changed: 11 additions & 2 deletions
@@ -33,8 +33,7 @@ class Optim:
     min_lr_ratio: Optional[float] = 0.0
     warmup_style: str = "constant"
     total_training_steps: int = -1
-    beta1: float = 0.9
-    beta2: float = 0.999
+    betas: List[float] = field(default_factory=lambda: [0.9, 0.999])


 @dataclass
@@ -82,6 +81,7 @@ class Actor:
     tau: float = 0.001 # strength of regularization w.r.t. old / ref policy
     opmd_baseline: str = "mean" # mean / logavgexp, applicable to opmd
     use_uid: bool = False # True / False, applicable to pairwise_opmd
+    loss_agg_mode: str = "token-mean" # do not set


 @dataclass
@@ -99,12 +99,20 @@ class _ValKwargs:
     do_sample: bool = False


+@dataclass
+class _MultiTurn:
+    enable: bool = False
+
+
 @dataclass
 class Rollout:
     # do not set
     val_kwargs: _ValKwargs = field(default_factory=_ValKwargs)
+    multi_turn: _MultiTurn = field(default_factory=_MultiTurn)
     temperature: float = 1.0
     n: int = 1 # > 1 for grpo
+    log_prob_micro_batch_size: Optional[int] = None
+    log_prob_micro_batch_size_per_gpu: int = 1


 @dataclass
@@ -148,6 +156,7 @@ class Critic:
     cliprange_value: float = 0.0
     checkpoint: Checkpoint = field(default_factory=Checkpoint)
     rollout_n: int = 1
+    loss_agg_mode: str = "token-mean"


 @dataclass
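
The pattern above is standard dataclass configuration: mutable defaults such as the `betas` list have to go through `field(default_factory=...)`, and nested sections like `multi_turn` are expressed as nested dataclasses. A self-contained sketch with simplified names (not an import from trinity):

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class Optim:
    warmup_style: str = "constant"
    total_training_steps: int = -1
    # A bare `= [0.9, 0.999]` default would be shared across instances; default_factory avoids that.
    betas: List[float] = field(default_factory=lambda: [0.9, 0.999])


@dataclass
class MultiTurn:
    enable: bool = False


@dataclass
class Rollout:
    multi_turn: MultiTurn = field(default_factory=MultiTurn)
    temperature: float = 1.0
    n: int = 1
    log_prob_micro_batch_size: Optional[int] = None
    log_prob_micro_batch_size_per_gpu: int = 1


print(Rollout().multi_turn.enable, Optim().betas)  # False [0.9, 0.999]
```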

trinity/trainer/verl/dp_actor.py

Lines changed: 42 additions & 176 deletions
@@ -1,4 +1,6 @@
 # Copyright 2024 Bytedance Ltd. and/or its affiliates
+# Copyright 2023-2024 SGLang Team
+# Copyright 2025 ModelBest Inc. and/or its affiliates
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,49 +14,42 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """
-Modified from dp_actor.py
+Single Process Actor.
+Modified from https://github.com/volcengine/verl/blob/0758489422e8d41a89e6c36d4c477714520f0dcc/verl/workers/actor/dp_actor.py
 """

 import itertools
-from typing import Tuple
+import logging
+import os

 import torch
-import verl.utils.torch_functional as verl_F
-from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input
 from torch import nn
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from verl import DataProto
+from verl.utils.debug import GPUMemoryLogger
+from verl.utils.device import get_torch_device
 from verl.utils.py_functional import append_to_dict
 from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches
-from verl.utils.torch_functional import logprobs_from_logits
-from verl.utils.ulysses import gather_outpus_and_unpad, ulysses_pad_and_slice_inputs
-from verl.workers.actor import BasePPOActor
+from verl.workers.actor.dp_actor import DataParallelPPOActor as DPActor

 from trinity.algorithm import ENTROPY_LOSS_FN, KL_FN, POLICY_LOSS_FN
+from trinity.algorithm.entropy_loss_fn.entropy_loss_fn import DummyEntropyLossFn
 from trinity.algorithm.kl_fn.kl_fn import DummyKLFn
 from trinity.algorithm.utils import prefix_metrics
 from trinity.common.config import AlgorithmConfig

 __all__ = ["DataParallelPPOActor"]

+logger = logging.getLogger(__file__)
+logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))

-class DataParallelPPOActor(BasePPOActor):
+
+class DataParallelPPOActor(DPActor):
     def __init__(
-        self,
-        config,
-        actor_module: nn.Module,
-        actor_optimizer: torch.optim.Optimizer = None,
+        self, config, actor_module: nn.Module, actor_optimizer: torch.optim.Optimizer = None
     ):
         """When optimizer is None, it is Reference Policy"""
-        super().__init__(config)
-        self.actor_module = actor_module
-        self.actor_optimizer = actor_optimizer
-        self.use_remove_padding = self.config.get("use_remove_padding", False)
-        print(f"Actor use_remove_padding={self.use_remove_padding}")
-        self.ulysses_sequence_parallel_size = self.config.ulysses_sequence_parallel_size
-        self.use_ulysses_sp = self.ulysses_sequence_parallel_size > 1
-
-        self.compute_entropy_from_logits = torch.compile(verl_F.entropy_from_logits, dynamic=True)
+        super().__init__(config, actor_module, actor_optimizer)
+
         self.policy_loss_fn = None
         self.kl_loss_fn = None
         self.entropy_loss_fn = None
@@ -68,150 +63,8 @@ def set_algorithm(self, algorithm_config: AlgorithmConfig):
            **algorithm_config.entropy_loss_fn_args
         )

-    def _forward_micro_batch(self, micro_batch, temperature) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
-        Returns:
-            entropy: # (bs, response_len)
-            log_probs: # (bs, response_len)
-        """
-        response_length = micro_batch["responses"].size(-1)
-        multi_modal_inputs = {}
-        if "multi_modal_inputs" in micro_batch:
-            for key in micro_batch["multi_modal_inputs"][0].keys():
-                multi_modal_inputs[key] = torch.cat(
-                    [inputs[key] for inputs in micro_batch["multi_modal_inputs"]], dim=0
-                )
-
-        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
-            input_ids = micro_batch["input_ids"]
-            batch_size, seqlen = input_ids.shape
-            attention_mask = micro_batch["attention_mask"]
-            position_ids = micro_batch["position_ids"]
-            if position_ids.dim() == 3: # qwen2vl mrope
-                position_ids = position_ids.transpose(0, 1) # (bsz, 3, seqlen) -> (3, bsz, seqlen)
-
-            if self.use_remove_padding:
-                input_ids_rmpad, indices, *_ = unpad_input(
-                    input_ids.unsqueeze(-1), attention_mask
-                ) # input_ids_rmpad (total_nnz, ...)
-                input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz)
-
-                # unpad the position_ids to align the rotary
-                if position_ids.dim() == 3:
-                    position_ids_rmpad = (
-                        index_first_axis(
-                            rearrange(position_ids, "c b s ... -> (b s) c ..."), indices
-                        )
-                        .transpose(0, 1)
-                        .unsqueeze(1)
-                    ) # (3, bsz, seqlen) -> (3, 1, bsz * seqlen)
-                else:
-                    position_ids_rmpad = index_first_axis(
-                        rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices
-                    ).transpose(0, 1)
-
-                # for compute the log_prob
-                input_ids_rmpad_rolled = torch.roll(
-                    input_ids_rmpad, shifts=-1, dims=1
-                ) # (1, total_nnz)
-
-                # pad and slice the inputs if sp > 1
-                if self.use_ulysses_sp:
-                    input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(
-                        input_ids_rmpad,
-                        position_ids_rmpad,
-                        sp_size=self.ulysses_sequence_parallel_size,
-                    )
-                    input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs(
-                        input_ids_rmpad_rolled, None, self.ulysses_sequence_parallel_size
-                    )
-
-                input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze(
-                    0
-                ) # ((total_nnz / sp) + pad)
-
-                # only pass input_ids and position_ids to enable flash_attn_varlen
-                output = self.actor_module(
-                    input_ids=input_ids_rmpad,
-                    attention_mask=None,
-                    position_ids=position_ids_rmpad,
-                    **multi_modal_inputs,
-                    use_cache=False,
-                ) # prevent model thinks we are generating
-                logits_rmpad = output.logits.squeeze(0) # (total_nnz, vocab_size)
-
-                logits_rmpad.div_(temperature)
-
-                # compute entropy
-                entropy_rmpad = self.compute_entropy_from_logits(
-                    logits_rmpad
-                ) # ((total_nnz / sp) + pad)
-
-                # if use_sp: ((total_nnz / sp) + pad) ; if not use_sp: (batch, seqlen)
-                log_probs = logprobs_from_logits(logits=logits_rmpad, labels=input_ids_rmpad_rolled)
-
-                # gather log_prob if sp > 1
-                if self.use_ulysses_sp:
-                    # gather and unpad for the ulysses sp
-                    log_probs = gather_outpus_and_unpad(
-                        log_probs, gather_dim=0, unpad_dim=0, padding_size=pad_size
-                    )
-                    entropy_rmpad = gather_outpus_and_unpad(
-                        entropy_rmpad, gather_dim=0, unpad_dim=0, padding_size=pad_size
-                    )
-                # pad back to (bsz, seqlen)
-                full_entropy = pad_input(
-                    hidden_states=entropy_rmpad.unsqueeze(-1),
-                    indices=indices,
-                    batch=batch_size,
-                    seqlen=seqlen,
-                )
-                full_log_probs = pad_input(
-                    hidden_states=log_probs.unsqueeze(-1),
-                    indices=indices,
-                    batch=batch_size,
-                    seqlen=seqlen,
-                )
-
-                # only return response part:
-                entropy = full_entropy.squeeze(-1)[
-                    :, -response_length - 1 : -1
-                ] # (bsz, response_length)
-                log_probs = full_log_probs.squeeze(-1)[
-                    :, -response_length - 1 : -1
-                ] # (bsz, response_length)
-
-            else: # not using rmpad and no ulysses sp
-                output = self.actor_module(
-                    input_ids=input_ids,
-                    attention_mask=attention_mask,
-                    position_ids=position_ids,
-                    **multi_modal_inputs,
-                    use_cache=False,
-                ) # prevent model thinks we are generating
-                logits = output.logits
-                logits.div_(temperature)
-                logits = logits[
-                    :, -response_length - 1 : -1, :
-                ] # (bsz, response_length, vocab_size)
-                log_probs = logprobs_from_logits(logits, micro_batch["responses"])
-                entropy = verl_F.entropy_from_logits(logits) # (bsz, response_length)
-
-        return entropy, log_probs
-
-    def _optimizer_step(self):
-        assert self.config.grad_clip is not None
-
-        if isinstance(self.actor_module, FSDP):
-            grad_norm = self.actor_module.clip_grad_norm_(max_norm=self.config.grad_clip)
-        else:
-            grad_norm = torch.nn.utils.clip_grad_norm_(
-                self.actor_module.parameters(), max_norm=self.config.grad_clip
-            )
-        self.actor_optimizer.step()
-        return grad_norm
-
-    def compute_log_prob(self, data: DataProto) -> torch.Tensor:
+    @GPUMemoryLogger(role="dp actor", logger=logger)
+    def compute_log_prob(self, data: DataProto, calculate_entropy=False) -> torch.Tensor:
         """Compute the log probability of the responses given input_ids, attention_mask and position_ids

         Args:
@@ -235,7 +88,7 @@ def compute_log_prob(self, data: DataProto) -> torch.Tensor:
         micro_batch_size = data.meta_info["micro_batch_size"]
         temperature = data.meta_info[
             "temperature"
-        ] # temperature must be in the data.meta_info to avoid slient error
+        ] # temperature must be in the data.meta_info to avoid silent error
         use_dynamic_bsz = data.meta_info["use_dynamic_bsz"]

         select_keys = ["responses", "input_ids", "attention_mask", "position_ids"]
@@ -258,30 +111,40 @@ def compute_log_prob(self, data: DataProto) -> torch.Tensor:
             micro_batches = batch.split(micro_batch_size)

         log_probs_lst = []
+        entropy_lst = []
         for micro_batch in micro_batches:
             if isinstance(micro_batch, DataProto):
                 micro_batch = {**micro_batch.batch, **micro_batch.non_tensor_batch}
-
             with torch.no_grad():
-                _, log_probs = self._forward_micro_batch(micro_batch, temperature=temperature)
+                entropy, log_probs = self._forward_micro_batch(
+                    micro_batch, temperature=temperature, calculate_entropy=calculate_entropy
+                )
                 log_probs_lst.append(log_probs)
-        log_probs = torch.concat(log_probs_lst, dim=0)
+                if calculate_entropy:
+                    entropy_lst.append(entropy)

+        log_probs = torch.concat(log_probs_lst, dim=0)
+        entropys = None
+        if calculate_entropy:
+            entropys = torch.concat(entropy_lst, dim=0)
         if use_dynamic_bsz:
             indices = list(itertools.chain.from_iterable(indices))
             assert len(indices) == log_probs.size(0), f"{len(indices)} vs. {log_probs.size()}"
             revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long)
             log_probs = log_probs[revert_indices]
+            if calculate_entropy:
+                entropys = entropys[revert_indices] # type: ignore

-        return log_probs
+        return log_probs, entropys

-    def update_policy(self, data: DataProto): # noqa: C901
+    @GPUMemoryLogger(role="dp actor", logger=logger)
+    def update_policy(self, data: DataProto):
         # make sure we are in training mode
         self.actor_module.train()

         temperature = data.meta_info[
             "temperature"
-        ] # temperature must be in the data.meta_info to avoid slient error
+        ] # temperature must be in the data.meta_info to avoid silent error
         select_keys = [
             "input_ids",
             "position_ids",
@@ -351,12 +214,12 @@ def update_policy(self, data: DataProto): # noqa: C901
                     # Support all hardwares
                     if isinstance(data, DataProto):
                         data = {
-                            **data.batch.to(torch.cuda.current_device()),
+                            **data.batch.to(get_torch_device().current_device()),
                             **data.non_tensor_batch,
                         }
                     else:
                         data = data.to(
-                            torch.cuda.current_device()
+                            get_torch_device().current_device()
                         ) # actor device is cpu when using offload
                     responses = data["responses"]
                     response_length = responses.size(1)
@@ -365,8 +228,11 @@ def update_policy(self, data: DataProto): # noqa: C901
                     assert response_mask.shape == attention_mask[:, -response_length:].shape

                     # all return: (bsz, response_length)
+                    calculate_entropy = self.entropy_loss_fn != DummyEntropyLossFn
                     entropy, log_prob = self._forward_micro_batch(
-                        micro_batch=data, temperature=temperature
+                        micro_batch=data,
+                        temperature=temperature,
+                        calculate_entropy=calculate_entropy,
                     )

                     kwargs = {