# Usage: python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml

import asyncio
-import time
import uuid
from dataclasses import dataclass
from typing import Any, Callable

import torch
import torchstore as ts
from datasets import load_dataset
from forge.actors.policy import Policy
-from forge.actors.reference_model import ReferenceModel  # noqa: F401
+from forge.actors.reference_model import ReferenceModel
from forge.actors.replay_buffer import ReplayBuffer
from forge.actors.torchstore_utils import get_param_key
-from forge.actors.trainer import _qwen3_hf_to_vllm
+from forge.actors.trainer import RLTrainer
from forge.cli.config import parse
from forge.controller.actor import ForgeActor
from forge.controller.provisioner import shutdown
from forge.data.rewards import MathReward, ThinkingReward
-from forge.losses.grpo_loss import SimpleGRPOLoss
from forge.util.metric_logging import get_metric_logger
from monarch.actor import endpoint
from omegaconf import DictConfig
-from torchstore.state_dict_utils import DELIM
-from torchtitan.config.job_config import Model as TitanJobModelConfig
-from transformers import AutoModelForCausalLM
from vllm.transformers_utils.tokenizer import get_tokenizer


-def compute_logprobs(
-    logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 1.0
-) -> torch.Tensor:
-    context_length = logits.shape[1] - input_ids.shape[1]
-
-    # Truncate request logits and drop last
-    logits = logits[:, context_length - 1 : -1]
-
-    # Compute logprobs
-    logprobs = torch.log_softmax(logits / temperature, dim=-1)
-    logprobs = torch.gather(logprobs, 2, input_ids.unsqueeze(-1)).squeeze(-1)
-
-    return logprobs
-
-
@dataclass
class Episode:
    # TODO: add additional layer for multi-turn
@@ -118,64 +98,64 @@ def new_group(
        return cls(str(group_id), episodes)


-@dataclass
-class Trainer(ForgeActor):
-    """GRPO Trainer implementation for policy optimization."""
-
-    model_name: str
-    learning_rate: float = 1e-5
-    beta: float = 0.1
-    device: torch.device | None = None
-    state_dict_key: str = "model_state_dict"
-    dp_rank: int = 0  # TODO: support data parallelism, hard code it for now
-
-    @endpoint
-    async def setup(self):
-        if self.device is None:
-            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-        self.model = AutoModelForCausalLM.from_pretrained(
-            self.model_name,
-            dtype=torch.bfloat16,
-            trust_remote_code=True,
-        ).to(self.device)
-        self.model.train()
-
-        self.optimizer = torch.optim.AdamW(
-            self.model.parameters(), lr=self.learning_rate
-        )
-        self.optimizer.zero_grad()
+def collate(batches: list[list[Episode]]):
+    inputs = []
+    targets = []
+    for batch in batches:
+        request = [e.request_tensor for e in batch]
+        request = torch.stack(request)  # [b x s]

-        self.loss = SimpleGRPOLoss(self.beta)
+        response = [e.response_tensor for e in batch]
+        response = torch.stack(response)  # [b x s]

-        self.logger.info(f"Trainer model initialized on {self.device}")
+        ref_logprobs = [e.ref_logprobs for e in batch]
+        ref_logprobs = torch.stack(ref_logprobs).squeeze()  # [b x s]

-    @endpoint
-    async def train_step(self, batch: list[list[Episode]]):
-        microbatch = batch[self.dp_rank]
-        pad_id = microbatch[0].pad_id
+        advantages = [e.advantage for e in batch]
+        advantages = torch.tensor(advantages).unsqueeze(-1)  # [b x 1]

-        # prepare batch
-        request = [e.request_tensor for e in microbatch]
-        request = torch.stack(request).to(self.device)  # [b x s]
+        pad_id = batch[0].pad_id
+        mask = response != pad_id

-        response = [e.response_tensor for e in microbatch]
-        response = torch.stack(response).to(self.device)  # [b x s]
+        input = {"tokens": torch.cat([request, response], dim=1)}
+        target = {
+            "response": response,
+            "ref_logprobs": ref_logprobs,
+            "advantages": advantages,
+            "padding_mask": mask,
+        }
+        inputs.append(input)
+        targets.append(target)
+    return inputs, targets

-        ref_logprobs = [e.ref_logprobs for e in microbatch]
-        ref_logprobs = torch.stack(ref_logprobs).to(self.device).squeeze()  # [b x s]

-        advantages = [e.advantage for e in microbatch]
-        advantages = torch.tensor(advantages).to(self.device).unsqueeze(-1)  # [b x 1]
-        del batch
+def compute_logprobs(
+    logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 1.0
+) -> torch.Tensor:
+    context_length = logits.shape[1] - input_ids.shape[1]
+    logits = logits[:, context_length - 1 : -1]
+    logprobs = torch.log_softmax(logits / temperature, dim=-1).to(input_ids.device)
+    logprobs = torch.gather(logprobs, 2, input_ids.unsqueeze(-1)).squeeze(-1)
+    return logprobs

-        input_ids = torch.cat([request, response], dim=1)
-        mask = input_ids != pad_id
-        logits = self.model(input_ids=input_ids, attention_mask=mask).logits
-        logprobs = compute_logprobs(logits, response)
-        del logits

-        mask = response != pad_id
+def simple_grpo_loss(
+    logits: torch.Tensor,
+    response: torch.Tensor,
+    ref_logprobs: torch.Tensor,
+    advantages: torch.Tensor,
+    padding_mask: torch.Tensor,
+    beta: float = 0.1,
+) -> torch.Tensor:
+    logprobs = compute_logprobs(logits, response)
+    kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1
+    per_token_policy_loss = torch.exp(logprobs - logprobs.detach()) * advantages
+    per_token_loss = -(per_token_policy_loss - beta * kl)
+    loss = (
+        ((per_token_loss * padding_mask).sum(dim=1))
+        / (padding_mask.sum(dim=1).clamp(min=1.0))
+    ).mean()
+    return loss
-        loss = self.loss(logprobs, ref_logprobs, advantages, mask)
-        loss.backward()
-        self.optimizer.step()
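
The hand-rolled `Trainer` actor is replaced in this hunk by three module-level helpers that are injected further down: `collate` packs replay-buffer episodes into `inputs`/`targets` dicts, `compute_logprobs` scores the response tokens against the shifted logits, and `simple_grpo_loss` combines a detached-ratio policy term with an `exp(r) - r - 1` KL penalty against the reference logprobs. A standalone sanity check of that math (made-up shapes, not part of the commit):

```python
# Illustrative only: checks the shape alignment behind compute_logprobs and
# the KL estimator used in simple_grpo_loss, on dummy tensors.
import torch

b, req_len, res_len, vocab = 2, 4, 3, 11

# The model sees [request | response]; logits[:, t] predicts token t + 1,
# so the response tokens are scored by logits[:, req_len - 1 : -1].
logits = torch.randn(b, req_len + res_len, vocab, requires_grad=True)
response = torch.randint(vocab, (b, res_len))

context_length = logits.shape[1] - response.shape[1]  # == req_len
shifted = logits[:, context_length - 1 : -1]          # [b, res_len, vocab]
logprobs = torch.log_softmax(shifted, dim=-1)
logprobs = torch.gather(logprobs, 2, response.unsqueeze(-1)).squeeze(-1)
assert logprobs.shape == (b, res_len)

# exp(r) - r - 1 with r = ref - cur is non-negative (up to float rounding)
# and zero only when the two logprob vectors match.
ref_logprobs = logprobs.detach() + 0.1 * torch.randn_like(logprobs)
r = ref_logprobs - logprobs
kl = torch.exp(r) - r - 1
assert (kl >= -1e-6).all()

# exp(logprobs - logprobs.detach()) is exactly 1.0 in the forward pass, so the
# forward value of the policy term is just the advantage; only the backward
# pass carries grad(logprobs) * advantage.
ratio = torch.exp(logprobs - logprobs.detach())
assert torch.allclose(ratio, torch.ones_like(ratio))
```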
@@ -223,38 +203,6 @@ async def compute(self, group: Group) -> list[float]:
        return advantages.squeeze(0).tolist()


-class RefModel(ForgeActor):
-    def __init__(self, model_name, device: torch.device | None = None):
-        super().__init__()
-        self.model_name = model_name
-
-        if device is None:
-            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        else:
-            self.device = device
-
-        self.model = AutoModelForCausalLM.from_pretrained(
-            model_name,
-            dtype=torch.bfloat16,
-            trust_remote_code=True,
-        ).to(self.device)
-        self.model.eval()
-
-        self.logger.info(f"Model initialized on {self.device}")
-
-    @endpoint
-    async def forward(self, episode: Episode) -> torch.Tensor:
-        req, res = episode.request_tensor, episode.response_tensor
-        input_ids = torch.cat([req, res]).to(self.device).unsqueeze(0)
-        mask = input_ids != episode.pad_id
-
-        with torch.inference_mode():
-            logits = self.model(input_ids=input_ids, attention_mask=mask).logits
-
-        input_ids = input_ids[:, len(req) :]
-        return compute_logprobs(logits, input_ids)
-
-
@dataclass
class DatasetActor(ForgeActor):
    """Actor wrapper for HuggingFace dataset to provide async interface."""
@@ -309,10 +257,7 @@ async def pad_token(self):

async def main(cfg: DictConfig):
    """Main GRPO training loop with rollout and training processes."""
-    titan_model = TitanJobModelConfig(name="qwen3", flavor="1.7B")
-    # Get parameters from config with fallbacks
    group_size = cfg.group_size
-    model = cfg.model
    max_req_tokens = cfg.max_req_tokens
    max_res_tokens = cfg.max_res_tokens
    mlogger = get_metric_logger(
@@ -322,7 +267,7 @@ async def main(cfg: DictConfig):
    )

    # ---- Setup services ---- #
-    await ts.initialize()
+    await ts.initialize(strategy=ts.ControllerStorageVolumes())
    (
        dataloader,
        policy,
@@ -334,17 +279,18 @@ async def main(cfg: DictConfig):
    ) = await asyncio.gather(
        DatasetActor.options(**cfg.services.dataset).as_service(**cfg.dataset),
        Policy.options(**cfg.services.policy).as_service(**cfg.policy),
-        Trainer.options(**cfg.services.trainer).as_service(**cfg.trainer),
+        RLTrainer.options(**cfg.services.trainer).as_service(
+            **cfg.trainer, loss=simple_grpo_loss
+        ),
        ReplayBuffer.options(**cfg.services.replay_buffer).as_service(
-            **cfg.replay_buffer
+            **cfg.replay_buffer, collate=collate
        ),
        ComputeAdvantages.options(**cfg.services.compute_advantages).as_service(),
-        RefModel.options(**cfg.services.ref_model).as_service(**cfg.ref_model),
+        ReferenceModel.options(**cfg.services.ref_model).as_service(**cfg.ref_model),
        RewardActor.options(**cfg.services.reward_actor).as_service(
            reward_functions=[MathReward(), ThinkingReward()]
        ),
    )
-
    print("All services initialized successfully!")

    # ---- Core RL loops ---- #
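
With the local `Trainer` gone, `RLTrainer` and `ReplayBuffer` are configured by injecting plain callables: `collate` shapes sampled episodes into `inputs`/`targets` dicts, and `simple_grpo_loss` consumes the matching target keys. How `RLTrainer` applies them internally is not shown in this diff; a hypothetical sketch of the assumed contract, for illustration only:

```python
# Assumption for illustration, not RLTrainer's actual implementation: each
# target dict from collate() is splatted into the injected loss alongside the
# logits computed from the matching input dict.
def apply_loss_step(model, loss_fn, inputs: dict, targets: dict):
    logits = model(inputs["tokens"])  # [b, s, vocab]
    return loss_fn(logits, **targets)  # response, ref_logprobs, advantages, padding_mask
```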
@@ -358,6 +304,7 @@ async def continuous_rollouts():
                return
            prompt, target = sample["request"], sample["target"]
            responses = await policy.generate.choose(prompt)
+            # TODO: this should be part of the response metadata instead of a separate call
            version = await policy.get_version.choose()
            group = Group.new_group(
                group_id=rollout_count,
@@ -370,20 +317,36 @@ async def continuous_rollouts():
                target=target,
            )

-            # TODO: Parallelize the following calculation
-            for episode, response in zip(group.episodes, responses.outputs):
-                episode.request_tokens = responses.prompt_token_ids
+            input_ids = torch.ones(
+                (group_size, max_req_tokens + max_res_tokens),
+                dtype=torch.long,
+                device="cuda",
+            )
+            # Populate episode info and calculate rewards
+            for i, (episode, response) in enumerate(zip(group.episodes, responses)):
+                episode.request_tokens = response.prompt_ids
                episode.response_tokens = response.token_ids
                episode.response = response.text
-                episode.ref_logprobs = await ref_model.forward.choose(episode)
+                input_ids[i, :max_req_tokens] = episode.request_tensor
+                input_ids[i, max_req_tokens:] = episode.response_tensor
                episode.reward = await reward_actor.evaluate_response.choose(
                    prompt=prompt, response=response.text, target=target
                )
+
+            # Calculate reference logprobs
+            ref_logits = await ref_model.forward.choose(input_ids)
+            ref_logprobs = compute_logprobs(ref_logits, input_ids[:, max_req_tokens:])
+            for i, episode in enumerate(group.episodes):
+                episode.ref_logprobs = ref_logprobs[i]
+            del ref_logits, ref_logprobs, input_ids
+
+            # Calculate advantages and add to replay buffer
            advantages = await compute_advantages.compute.choose(group)
            for episode, advantage in zip(group.episodes, advantages):
                episode.advantage = advantage
                await replay_buffer.add.choose(episode)

+            # Log metrics
            avg_response_len = (
                sum(len(e.response_tokens) for e in group.episodes) / group_size
            )
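
The per-group buffer packs each episode as `[request | response]` with fixed widths `max_req_tokens` and `max_res_tokens`, so a single `ReferenceModel` forward covers the whole group and `input_ids[:, max_req_tokens:]` is always the response block. A toy, standalone illustration of that layout (made-up sizes; assumes the episode tensors are already padded to those widths, as the fixed slices require):

```python
# Toy illustration of the padded [request | response] layout used above.
import torch

group_size, max_req_tokens, max_res_tokens, pad_id = 2, 4, 3, 0
input_ids = torch.ones((group_size, max_req_tokens + max_res_tokens), dtype=torch.long)

request_tensor = torch.tensor([pad_id, pad_id, 5, 6])  # padded request, width 4
response_tensor = torch.tensor([7, 8, pad_id])         # padded response, width 3
input_ids[0, :max_req_tokens] = request_tensor
input_ids[0, max_req_tokens:] = response_tensor

# The reference pass scores only the response slice of each row.
response_slice = input_ids[:, max_req_tokens:]
assert response_slice.shape == (group_size, max_res_tokens)
```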
@@ -402,7 +365,8 @@ async def continuous_training():
            if batch is None:
                await asyncio.sleep(0.1)
            else:
-                loss = await trainer.train_step.choose(batch)
+                inputs, targets = batch
+                loss = await trainer.train_step.choose(inputs, targets)
                training_step += 1
                mlogger.log("loss/training_step", loss, training_step)
                start_time = time.perf_counter()