11# Copyright 2024 Bytedance Ltd. and/or its affiliates
2+ # Copyright 2023-2024 SGLang Team
3+ # Copyright 2025 ModelBest Inc. and/or its affiliates
24#
35# Licensed under the Apache License, Version 2.0 (the "License");
46# you may not use this file except in compliance with the License.
1214# See the License for the specific language governing permissions and
1315# limitations under the License.
1416"""
15- Modified from dp_actor.py
17+ Single Process Actor.
18+ Modified from https://github.com/volcengine/verl/blob/0758489422e8d41a89e6c36d4c477714520f0dcc/verl/workers/actor/dp_actor.py
1619"""
1720
1821import itertools
19- from typing import Tuple
22+ import logging
23+ import os
2024
2125import torch
22- import verl .utils .torch_functional as verl_F
23- from flash_attn .bert_padding import index_first_axis , pad_input , rearrange , unpad_input
2426from torch import nn
25- from torch .distributed .fsdp import FullyShardedDataParallel as FSDP
2627from verl import DataProto
28+ from verl .utils .debug import GPUMemoryLogger
29+ from verl .utils .device import get_torch_device
2730from verl .utils .py_functional import append_to_dict
2831from verl .utils .seqlen_balancing import get_reverse_idx , rearrange_micro_batches
29- from verl .utils .torch_functional import logprobs_from_logits
30- from verl .utils .ulysses import gather_outpus_and_unpad , ulysses_pad_and_slice_inputs
31- from verl .workers .actor import BasePPOActor
32+ from verl .workers .actor .dp_actor import DataParallelPPOActor as DPActor
3233
3334from trinity .algorithm import ENTROPY_LOSS_FN , KL_FN , POLICY_LOSS_FN
35+ from trinity .algorithm .entropy_loss_fn .entropy_loss_fn import DummyEntropyLossFn
3436from trinity .algorithm .kl_fn .kl_fn import DummyKLFn
3537from trinity .algorithm .utils import prefix_metrics
3638from trinity .common .config import AlgorithmConfig
3739
3840__all__ = ["DataParallelPPOActor" ]
3941
42+ logger = logging .getLogger (__file__ )
43+ logger .setLevel (os .getenv ("VERL_LOGGING_LEVEL" , "WARN" ))
4044
41- class DataParallelPPOActor (BasePPOActor ):
45+
46+ class DataParallelPPOActor (DPActor ):
4247 def __init__ (
43- self ,
44- config ,
45- actor_module : nn .Module ,
46- actor_optimizer : torch .optim .Optimizer = None ,
48+ self , config , actor_module : nn .Module , actor_optimizer : torch .optim .Optimizer = None
4749 ):
4850 """When optimizer is None, it is Reference Policy"""
49- super ().__init__ (config )
50- self .actor_module = actor_module
51- self .actor_optimizer = actor_optimizer
52- self .use_remove_padding = self .config .get ("use_remove_padding" , False )
53- print (f"Actor use_remove_padding={ self .use_remove_padding } " )
54- self .ulysses_sequence_parallel_size = self .config .ulysses_sequence_parallel_size
55- self .use_ulysses_sp = self .ulysses_sequence_parallel_size > 1
56-
57- self .compute_entropy_from_logits = torch .compile (verl_F .entropy_from_logits , dynamic = True )
51+ super ().__init__ (config , actor_module , actor_optimizer )
52+
5853 self .policy_loss_fn = None
5954 self .kl_loss_fn = None
6055 self .entropy_loss_fn = None
@@ -68,150 +63,8 @@ def set_algorithm(self, algorithm_config: AlgorithmConfig):
6863 ** algorithm_config .entropy_loss_fn_args
6964 )
7065
71- def _forward_micro_batch (self , micro_batch , temperature ) -> Tuple [torch .Tensor , torch .Tensor ]:
72- """
73- Returns:
74- entropy: # (bs, response_len)
75- log_probs: # (bs, response_len)
76- """
77- response_length = micro_batch ["responses" ].size (- 1 )
78- multi_modal_inputs = {}
79- if "multi_modal_inputs" in micro_batch :
80- for key in micro_batch ["multi_modal_inputs" ][0 ].keys ():
81- multi_modal_inputs [key ] = torch .cat (
82- [inputs [key ] for inputs in micro_batch ["multi_modal_inputs" ]], dim = 0
83- )
84-
85- with torch .autocast (device_type = "cuda" , dtype = torch .bfloat16 ):
86- input_ids = micro_batch ["input_ids" ]
87- batch_size , seqlen = input_ids .shape
88- attention_mask = micro_batch ["attention_mask" ]
89- position_ids = micro_batch ["position_ids" ]
90- if position_ids .dim () == 3 : # qwen2vl mrope
91- position_ids = position_ids .transpose (0 , 1 ) # (bsz, 3, seqlen) -> (3, bsz, seqlen)
92-
93- if self .use_remove_padding :
94- input_ids_rmpad , indices , * _ = unpad_input (
95- input_ids .unsqueeze (- 1 ), attention_mask
96- ) # input_ids_rmpad (total_nnz, ...)
97- input_ids_rmpad = input_ids_rmpad .transpose (0 , 1 ) # (1, total_nnz)
98-
99- # unpad the position_ids to align the rotary
100- if position_ids .dim () == 3 :
101- position_ids_rmpad = (
102- index_first_axis (
103- rearrange (position_ids , "c b s ... -> (b s) c ..." ), indices
104- )
105- .transpose (0 , 1 )
106- .unsqueeze (1 )
107- ) # (3, bsz, seqlen) -> (3, 1, bsz * seqlen)
108- else :
109- position_ids_rmpad = index_first_axis (
110- rearrange (position_ids .unsqueeze (- 1 ), "b s ... -> (b s) ..." ), indices
111- ).transpose (0 , 1 )
112-
113- # for compute the log_prob
114- input_ids_rmpad_rolled = torch .roll (
115- input_ids_rmpad , shifts = - 1 , dims = 1
116- ) # (1, total_nnz)
117-
118- # pad and slice the inputs if sp > 1
119- if self .use_ulysses_sp :
120- input_ids_rmpad , position_ids_rmpad , pad_size = ulysses_pad_and_slice_inputs (
121- input_ids_rmpad ,
122- position_ids_rmpad ,
123- sp_size = self .ulysses_sequence_parallel_size ,
124- )
125- input_ids_rmpad_rolled , _ , _ = ulysses_pad_and_slice_inputs (
126- input_ids_rmpad_rolled , None , self .ulysses_sequence_parallel_size
127- )
128-
129- input_ids_rmpad_rolled = input_ids_rmpad_rolled .squeeze (
130- 0
131- ) # ((total_nnz / sp) + pad)
132-
133- # only pass input_ids and position_ids to enable flash_attn_varlen
134- output = self .actor_module (
135- input_ids = input_ids_rmpad ,
136- attention_mask = None ,
137- position_ids = position_ids_rmpad ,
138- ** multi_modal_inputs ,
139- use_cache = False ,
140- ) # prevent model thinks we are generating
141- logits_rmpad = output .logits .squeeze (0 ) # (total_nnz, vocab_size)
142-
143- logits_rmpad .div_ (temperature )
144-
145- # compute entropy
146- entropy_rmpad = self .compute_entropy_from_logits (
147- logits_rmpad
148- ) # ((total_nnz / sp) + pad)
149-
150- # if use_sp: ((total_nnz / sp) + pad) ; if not use_sp: (batch, seqlen)
151- log_probs = logprobs_from_logits (logits = logits_rmpad , labels = input_ids_rmpad_rolled )
152-
153- # gather log_prob if sp > 1
154- if self .use_ulysses_sp :
155- # gather and unpad for the ulysses sp
156- log_probs = gather_outpus_and_unpad (
157- log_probs , gather_dim = 0 , unpad_dim = 0 , padding_size = pad_size
158- )
159- entropy_rmpad = gather_outpus_and_unpad (
160- entropy_rmpad , gather_dim = 0 , unpad_dim = 0 , padding_size = pad_size
161- )
162- # pad back to (bsz, seqlen)
163- full_entropy = pad_input (
164- hidden_states = entropy_rmpad .unsqueeze (- 1 ),
165- indices = indices ,
166- batch = batch_size ,
167- seqlen = seqlen ,
168- )
169- full_log_probs = pad_input (
170- hidden_states = log_probs .unsqueeze (- 1 ),
171- indices = indices ,
172- batch = batch_size ,
173- seqlen = seqlen ,
174- )
175-
176- # only return response part:
177- entropy = full_entropy .squeeze (- 1 )[
178- :, - response_length - 1 : - 1
179- ] # (bsz, response_length)
180- log_probs = full_log_probs .squeeze (- 1 )[
181- :, - response_length - 1 : - 1
182- ] # (bsz, response_length)
183-
184- else : # not using rmpad and no ulysses sp
185- output = self .actor_module (
186- input_ids = input_ids ,
187- attention_mask = attention_mask ,
188- position_ids = position_ids ,
189- ** multi_modal_inputs ,
190- use_cache = False ,
191- ) # prevent model thinks we are generating
192- logits = output .logits
193- logits .div_ (temperature )
194- logits = logits [
195- :, - response_length - 1 : - 1 , :
196- ] # (bsz, response_length, vocab_size)
197- log_probs = logprobs_from_logits (logits , micro_batch ["responses" ])
198- entropy = verl_F .entropy_from_logits (logits ) # (bsz, response_length)
199-
200- return entropy , log_probs
201-
202- def _optimizer_step (self ):
203- assert self .config .grad_clip is not None
204-
205- if isinstance (self .actor_module , FSDP ):
206- grad_norm = self .actor_module .clip_grad_norm_ (max_norm = self .config .grad_clip )
207- else :
208- grad_norm = torch .nn .utils .clip_grad_norm_ (
209- self .actor_module .parameters (), max_norm = self .config .grad_clip
210- )
211- self .actor_optimizer .step ()
212- return grad_norm
213-
214- def compute_log_prob (self , data : DataProto ) -> torch .Tensor :
66+ @GPUMemoryLogger (role = "dp actor" , logger = logger )
67+ def compute_log_prob (self , data : DataProto , calculate_entropy = False ) -> torch .Tensor :
21568 """Compute the log probability of the responses given input_ids, attention_mask and position_ids
21669
21770 Args:
@@ -235,7 +88,7 @@ def compute_log_prob(self, data: DataProto) -> torch.Tensor:
23588 micro_batch_size = data .meta_info ["micro_batch_size" ]
23689 temperature = data .meta_info [
23790 "temperature"
238- ] # temperature must be in the data.meta_info to avoid slient error
91+ ] # temperature must be in the data.meta_info to avoid silent error
23992 use_dynamic_bsz = data .meta_info ["use_dynamic_bsz" ]
24093
24194 select_keys = ["responses" , "input_ids" , "attention_mask" , "position_ids" ]
@@ -258,30 +111,40 @@ def compute_log_prob(self, data: DataProto) -> torch.Tensor:
258111 micro_batches = batch .split (micro_batch_size )
259112
260113 log_probs_lst = []
114+ entropy_lst = []
261115 for micro_batch in micro_batches :
262116 if isinstance (micro_batch , DataProto ):
263117 micro_batch = {** micro_batch .batch , ** micro_batch .non_tensor_batch }
264-
265118 with torch .no_grad ():
266- _ , log_probs = self ._forward_micro_batch (micro_batch , temperature = temperature )
119+ entropy , log_probs = self ._forward_micro_batch (
120+ micro_batch , temperature = temperature , calculate_entropy = calculate_entropy
121+ )
267122 log_probs_lst .append (log_probs )
268- log_probs = torch .concat (log_probs_lst , dim = 0 )
123+ if calculate_entropy :
124+ entropy_lst .append (entropy )
269125
126+ log_probs = torch .concat (log_probs_lst , dim = 0 )
127+ entropys = None
128+ if calculate_entropy :
129+ entropys = torch .concat (entropy_lst , dim = 0 )
270130 if use_dynamic_bsz :
271131 indices = list (itertools .chain .from_iterable (indices ))
272132 assert len (indices ) == log_probs .size (0 ), f"{ len (indices )} vs. { log_probs .size ()} "
273133 revert_indices = torch .tensor (get_reverse_idx (indices ), dtype = torch .long )
274134 log_probs = log_probs [revert_indices ]
135+ if calculate_entropy :
136+ entropys = entropys [revert_indices ] # type: ignore
275137
276- return log_probs
138+ return log_probs , entropys
277139
278- def update_policy (self , data : DataProto ): # noqa: C901
140+ @GPUMemoryLogger (role = "dp actor" , logger = logger )
141+ def update_policy (self , data : DataProto ):
279142 # make sure we are in training mode
280143 self .actor_module .train ()
281144
282145 temperature = data .meta_info [
283146 "temperature"
284- ] # temperature must be in the data.meta_info to avoid slient error
147+ ] # temperature must be in the data.meta_info to avoid silent error
285148 select_keys = [
286149 "input_ids" ,
287150 "position_ids" ,
@@ -351,12 +214,12 @@ def update_policy(self, data: DataProto): # noqa: C901
351214 # Support all hardwares
352215 if isinstance (data , DataProto ):
353216 data = {
354- ** data .batch .to (torch . cuda .current_device ()),
217+ ** data .batch .to (get_torch_device () .current_device ()),
355218 ** data .non_tensor_batch ,
356219 }
357220 else :
358221 data = data .to (
359- torch . cuda .current_device ()
222+ get_torch_device () .current_device ()
360223 ) # actor device is cpu when using offload
361224 responses = data ["responses" ]
362225 response_length = responses .size (1 )
@@ -365,8 +228,11 @@ def update_policy(self, data: DataProto): # noqa: C901
365228 assert response_mask .shape == attention_mask [:, - response_length :].shape
366229
367230 # all return: (bsz, response_length)
231+ calculate_entropy = self .entropy_loss_fn != DummyEntropyLossFn
368232 entropy , log_prob = self ._forward_micro_batch (
369- micro_batch = data , temperature = temperature
233+ micro_batch = data ,
234+ temperature = temperature ,
235+ calculate_entropy = calculate_entropy ,
370236 )
371237
372238 kwargs = {
0 commit comments