 # LICENSE file in the root directory of this source tree.

 import asyncio
+import logging
 import time
 from dataclasses import dataclass
 from typing import Callable

 import torch
 from datasets import load_dataset
 from forge.actors.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig
+from forge.actors.reference_actor import compute_sequence_logprobs, TitanRefModel
 from forge.actors.replay_buffer import ReplayBuffer
-from forge.controller import ServiceConfig, spawn_service
 from forge.controller.actor import ForgeActor
+from forge.controller.service import ServiceConfig, shutdown_service, spawn_service
 from forge.data.rewards import MathReward, ThinkingReward
 from forge.util.metric_logging import get_metric_logger
 from monarch.actor import endpoint
+from torchtitan.config.job_config import Model as TitanJobModelConfig
 from transformers import AutoModelForCausalLM, AutoTokenizer

-
-def compute_sequence_logprobs(
-    model: torch.nn.Module,
-    input_ids: torch.Tensor,
-    attention_mask: torch.Tensor,
-    requires_grad: bool = True,
-) -> torch.Tensor:
-    context_manager = torch.enable_grad() if requires_grad else torch.no_grad()
-
-    with context_manager:
-        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
-        logits = outputs.logits
-
-        # Apply log softmax to get log probabilities
-        log_probs = torch.log_softmax(logits, dim=-1)
-
-        # Extract log probabilities for the actual tokens (excluding the first token for next-token prediction)
-        shifted_input_ids = input_ids[:, 1:]  # Remove first token
-        shifted_log_probs = log_probs[:, :-1, :]  # Remove last logit
-
-        # Gather log probabilities for actual tokens
-        token_log_probs = torch.gather(
-            shifted_log_probs, dim=-1, index=shifted_input_ids.unsqueeze(-1)
-        ).squeeze(-1)
-
-        # Sum log probabilities across sequence (masked by attention)
-        shifted_attention_mask = attention_mask[:, 1:]
-        sequence_log_probs = (token_log_probs * shifted_attention_mask).sum(dim=-1)
-
-    return sequence_log_probs
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)


 @dataclass
@@ -269,63 +244,21 @@ async def __call__(self, groups: list[Group]) -> list[float]:
         return advantages


-class RefModel(ForgeActor):
-    def __init__(self, model_name, device: torch.device | None = None):
-        super().__init__()
-        self.model_name = model_name
-
-        # Set device
-        if device is None:
-            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        else:
-            self.device = device
-
-        # Initialize model and tokenizer
-        self.model = AutoModelForCausalLM.from_pretrained(
-            model_name,
-            torch_dtype=torch.bfloat16,
-            trust_remote_code=True,
-        ).to(self.device)
-
-        # Set model to eval mode for reference computations
-        self.model.eval()
-
-        self.logger.info(f"Model initialized on {self.device}")
-
-    @endpoint
-    async def forward(self, token_ids: list[int]) -> torch.Tensor:
-        # Use provided token_ids directly
-        input_ids = (
-            torch.tensor(token_ids, dtype=torch.long).unsqueeze(0).to(self.device)
-        )
-        # Create attention mask of all 1s since we have actual tokens (no padding)
-        attention_mask = torch.ones_like(input_ids).to(self.device)
-
-        # Compute log probabilities using shared utility function
-        sequence_log_probs = compute_sequence_logprobs(
-            self.model, input_ids, attention_mask, requires_grad=False
-        )
-
-        return (
-            sequence_log_probs.squeeze()
-        )  # Remove batch dimension for single response
-
-
 class DatasetActor(ForgeActor):
     """Actor wrapper for HuggingFace dataset to provide async interface."""

-    def __init__(self, *args, **kwargs):
+    def __init__(
+        self, path: str, config_name: str, split: str, streaming: bool, **kwargs
+    ):
         super().__init__()
-        self._setup_dataset(*args, **kwargs)

-    def _setup_dataset(self, *args, **kwargs):
         def gsm8k_to_messages(sample):
             question = sample["question"]
             full_answer: str = sample["answer"]
             answer = full_answer.split("#### ")[1]
             return {"question": question, "answer": answer}

-        ds = load_dataset(*args, **kwargs)
+        ds = load_dataset(path, config_name, split=split, streaming=streaming)
         ds = ds.map(gsm8k_to_messages)
         ds = ds.shuffle()
         self._iterator = iter(ds)
@@ -341,7 +274,8 @@ async def __next__(self) -> dict[str, str] | None:
 async def main():
     """Main GRPO training loop with rollout and training processes."""
     group_size = 1
-    model = "Qwen/Qwen3-1.7B"
+    model = "Qwen/Qwen3-0.6B"
+    titan_model = TitanJobModelConfig(name="qwen3", flavor="0.6B")

     # ---- Setup WandB Logger ---- #
     logger = get_metric_logger(
@@ -351,74 +285,69 @@ async def main():
     )

     # ---- Setup services ---- #
-    default_service_cfg = ServiceConfig(
-        procs_per_replica=1,
-        num_replicas=1,
-    )
-
-    policy = await spawn_service(
-        default_service_cfg,
-        Policy,
-        PolicyConfig(
-            num_workers=1,
-            worker_params=WorkerConfig(model=model),
-            sampling_params=SamplingOverrides(num_samples=group_size, max_tokens=16),
-            available_devices="3",
+    (
+        dataloader,
+        policy,
+        trainer,
+        replay_buffer,
+        compute_advantages,
+        ref_model,
+        reward_actor,
+    ) = await asyncio.gather(
+        spawn_service(
+            ServiceConfig(procs_per_replica=1, num_replicas=1),
+            DatasetActor,
+            path="openai/gsm8k",
+            config_name="main",
+            split="train",
+            streaming=True,
+        ),
+        spawn_service(
+            ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=1),
+            Policy,
+            config=PolicyConfig(
+                worker_params=WorkerConfig(model=model),
+                sampling_params=SamplingOverrides(
+                    num_samples=group_size, max_tokens=16
+                ),
+            ),
+        ),
+        spawn_service(
+            ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=1),
+            Trainer,
+            learning_rate=1e-5,
+            beta=0.1,
+            model_name=model,
+        ),
+        spawn_service(
+            ServiceConfig(procs_per_replica=1, num_replicas=1),
+            ReplayBuffer,
+            batch_size=4,
+            max_policy_age=1,
+        ),
+        spawn_service(
+            ServiceConfig(procs_per_replica=1, num_replicas=1),
+            ComputeAdvantages,
+            gamma=0.99,
+            lambda_=0.95,
+        ),
+        spawn_service(
+            ServiceConfig(procs_per_replica=1, num_replicas=1, with_gpus=True),
+            TitanRefModel,
+            model=titan_model,
+        ),
+        spawn_service(
+            ServiceConfig(procs_per_replica=1, num_replicas=1),
+            RewardActor,
+            reward_functions=[MathReward(), ThinkingReward()],
         ),
-    )
-
-    trainer = await spawn_service(
-        default_service_cfg,
-        Trainer,
-        learning_rate=1e-5,
-        beta=0.1,
-        model_name=model,
-        device=torch.device("cuda:1"),
-    )
-
-    replay_buffer = await spawn_service(
-        default_service_cfg,
-        ReplayBuffer,
-        batch_size=4,
-        max_policy_age=1,
-    )
-
-    dataloader = await spawn_service(
-        default_service_cfg,
-        DatasetActor,
-        "openai/gsm8k",
-        "main",
-        split="train",
-        streaming=True,
-    )
-
-    compute_advantages = await spawn_service(
-        default_service_cfg,
-        ComputeAdvantages,
-        gamma=0.99,
-        lambda_=0.95,
-    )
-
-    ref_model = await spawn_service(
-        default_service_cfg,
-        RefModel,
-        model_name=model,
-        device=torch.device("cuda:2"),
-    )
-
-    reward_actor = await spawn_service(
-        default_service_cfg,
-        RewardActor,
-        reward_functions=[MathReward(), ThinkingReward()],
     )

     print("All services initialized successfully!")

     # ---- Core RL loops ---- #
     async def continuous_rollouts():
         rollout_count = 0
-        # TODO: Move this into setup
-        asyncio.create_task(policy.run_processing.call())
         while True:
             sample = await dataloader.__next__.choose()
             if sample is None:
@@ -432,9 +361,14 @@ async def continuous_rollouts():
                 target=target,
                 policy_version=version,
             )
-            actions = await policy.generate.choose(prompt)
+            responses = await policy.generate.choose(prompt)
+            actions = responses.outputs
             for action in actions:
-                ref_logprobs = await ref_model.forward.choose(action.token_ids)
+                request_tokens = responses.prompt_token_ids
+                response_tokens = action.token_ids
+                ref_logprobs = await ref_model.forward.choose(
+                    request=request_tokens, response=response_tokens
+                )
                 reward = await reward_actor.evaluate_response.choose(
                     prompt=prompt, response=action.text, target=target
                 )
@@ -489,6 +423,17 @@ async def continuous_training():
         print("Training interrupted by user")
         rollout_task.cancel()
         training_task.cancel()
+    finally:
+        print("Shutting down...")
+        await asyncio.gather(
+            shutdown_service(policy),
+            shutdown_service(trainer),
+            shutdown_service(replay_buffer),
+            shutdown_service(dataloader),
+            shutdown_service(compute_advantages),
+            shutdown_service(ref_model),
+            shutdown_service(reward_actor),
+        )


 if __name__ == "__main__":