import torch
from datasets import load_dataset
from forge.actors.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig
+from forge.actors.reference_actor import compute_sequence_logprobs, TitanRefModel
from forge.actors.replay_buffer import ReplayBuffer
from forge.controller.actor import ForgeActor
from forge.controller.service import ServiceConfig, shutdown_service, spawn_service
from forge.data.rewards import MathReward, ThinkingReward
from forge.util.metric_logging import get_metric_logger
from monarch.actor import endpoint
+from torchtitan.config.job_config import Model as TitanJobModelConfig
from transformers import AutoModelForCausalLM, AutoTokenizer

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


-def compute_sequence_logprobs(
-    model: torch.nn.Module,
-    input_ids: torch.Tensor,
-    attention_mask: torch.Tensor,
-    requires_grad: bool = True,
-) -> torch.Tensor:
-    context_manager = torch.enable_grad() if requires_grad else torch.no_grad()
-
-    with context_manager:
-        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
-        logits = outputs.logits
-
-        # Apply log softmax to get log probabilities
-        log_probs = torch.log_softmax(logits, dim=-1)
-
-        # Extract log probabilities for the actual tokens (excluding the first token for next-token prediction)
-        shifted_input_ids = input_ids[:, 1:]  # Remove first token
-        shifted_log_probs = log_probs[:, :-1, :]  # Remove last logit
-
-        # Gather log probabilities for actual tokens
-        token_log_probs = torch.gather(
-            shifted_log_probs, dim=-1, index=shifted_input_ids.unsqueeze(-1)
-        ).squeeze(-1)
-
-        # Sum log probabilities across sequence (masked by attention)
-        shifted_attention_mask = attention_mask[:, 1:]
-        sequence_log_probs = (token_log_probs * shifted_attention_mask).sum(dim=-1)
-
-        return sequence_log_probs
-
-
@dataclass
class Group:
    response: str  # The response text for tokenization
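
Note: the compute_sequence_logprobs helper deleted above is now imported from forge.actors.reference_actor (see the added import). As a reminder of what that computation does, here is a minimal, self-contained sketch of the same shift-and-gather sequence log-prob calculation on toy tensors; the shapes and random values are illustrative only, and it assumes nothing beyond plain PyTorch:

import torch

# Toy batch: 1 sequence of 4 tokens over a vocabulary of 8 (illustrative sizes).
logits = torch.randn(1, 4, 8)             # what a causal LM head would return
input_ids = torch.randint(0, 8, (1, 4))
attention_mask = torch.ones(1, 4)

log_probs = torch.log_softmax(logits, dim=-1)
shifted_input_ids = input_ids[:, 1:]      # tokens being predicted
shifted_log_probs = log_probs[:, :-1, :]  # positions that predict them
token_log_probs = torch.gather(
    shifted_log_probs, dim=-1, index=shifted_input_ids.unsqueeze(-1)
).squeeze(-1)
# Zero out padded positions (none in this toy case) and sum per sequence.
sequence_log_probs = (token_log_probs * attention_mask[:, 1:]).sum(dim=-1)
print(sequence_log_probs.shape)           # torch.Size([1])

Each position's logits predict the next token, so the inputs are shifted left and the last logit dropped before gathering, and masked positions are excluded from the per-sequence sum.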
@@ -273,48 +244,6 @@ async def __call__(self, groups: list[Group]) -> list[float]:
        return advantages


-class RefModel(ForgeActor):
-    def __init__(self, model_name, device: torch.device | None = None):
-        super().__init__()
-        self.model_name = model_name
-
-        # Set device
-        if device is None:
-            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        else:
-            self.device = device
-
-        # Initialize model and tokenizer
-        self.model = AutoModelForCausalLM.from_pretrained(
-            model_name,
-            torch_dtype=torch.bfloat16,
-            trust_remote_code=True,
-        ).to(self.device)
-
-        # Set model to eval mode for reference computations
-        self.model.eval()
-
-        self.logger.info(f"Model initialized on {self.device}")
-
-    @endpoint
-    async def forward(self, token_ids: list[int]) -> torch.Tensor:
-        # Use provided token_ids directly
-        input_ids = (
-            torch.tensor(token_ids, dtype=torch.long).unsqueeze(0).to(self.device)
-        )
-        # Create attention mask of all 1s since we have actual tokens (no padding)
-        attention_mask = torch.ones_like(input_ids).to(self.device)
-
-        # Compute log probabilities using shared utility function
-        sequence_log_probs = compute_sequence_logprobs(
-            self.model, input_ids, attention_mask, requires_grad=False
-        )
-
-        return (
-            sequence_log_probs.squeeze()
-        )  # Remove batch dimension for single response
-
-
class DatasetActor(ForgeActor):
    """Actor wrapper for HuggingFace dataset to provide async interface."""

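Note: the deleted RefModel above and its TitanRefModel replacement both follow the ForgeActor pattern used throughout this file: a class exposes async @endpoint methods, is spawned with spawn_service(ServiceConfig(...), ActorClass, **constructor_kwargs), and is called through routed endpoint calls such as .choose(). The sketch below illustrates that pattern with a hypothetical CountingActor; the actor, its name argument, and the demo coroutine are invented for illustration, and it assumes spawn_service is awaited to obtain a service handle, as its use inside the async main() here suggests:

from forge.controller.actor import ForgeActor
from forge.controller.service import ServiceConfig, shutdown_service, spawn_service
from monarch.actor import endpoint


class CountingActor(ForgeActor):  # hypothetical actor, for illustration only
    def __init__(self, name: str):
        super().__init__()
        self.name = name

    @endpoint
    async def forward(self, token_ids: list[int]) -> int:
        # Trivial stand-in for real work such as reference log-prob computation.
        return len(token_ids)


async def demo():
    # One replica on one process, mirroring the ServiceConfig usage in this file.
    counter = await spawn_service(
        ServiceConfig(procs_per_replica=1, num_replicas=1),
        CountingActor,
        name="demo",
    )
    n = await counter.forward.choose(token_ids=[1, 2, 3])  # routed endpoint call
    await shutdown_service(counter)  # assumed teardown, mirroring the import above
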
@@ -345,7 +274,8 @@ async def __next__(self) -> dict[str, str] | None:
async def main():
    """Main GRPO training loop with rollout and training processes."""
    group_size = 1
-    model = "Qwen/Qwen3-1.7B"
+    model = "Qwen/Qwen3-0.6B"
+    titan_model = TitanJobModelConfig(name="qwen3", flavor="0.6B")

    # ---- Setup WandB Logger ---- #
    logger = get_metric_logger(
@@ -403,8 +333,8 @@ async def main():
        ),
        spawn_service(
            ServiceConfig(procs_per_replica=1, num_replicas=1, with_gpus=True),
-            RefModel,
-            model_name=model,
+            TitanRefModel,
+            model=titan_model,
        ),
        spawn_service(
            ServiceConfig(procs_per_replica=1, num_replicas=1),
@@ -431,9 +361,14 @@ async def continuous_rollouts():
                target=target,
                policy_version=version,
            )
-            actions = await policy.generate.choose(prompt)
+            responses = await policy.generate.choose(prompt)
+            actions = responses.outputs
            for action in actions:
-                ref_logprobs = await ref_model.forward.choose(action.token_ids)
+                request_tokens = responses.prompt_token_ids
+                response_tokens = action.token_ids
+                ref_logprobs = await ref_model.forward.choose(
+                    request=request_tokens, response=response_tokens
+                )
                reward = await reward_actor.evaluate_response.choose(
                    prompt=prompt, response=action.text, target=target
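
Note: the rollout change above assumes policy.generate.choose(prompt) now returns a vLLM-style result whose prompt_token_ids hold the request tokens and whose outputs each hold one sampled completion's token_ids and text. The sketch below is a hypothetical stand-in for that shape, inferred only from how the new loop consumes it; the class names and any field not used above are assumptions:

from dataclasses import dataclass, field


@dataclass
class SampledCompletion:  # hypothetical element of responses.outputs
    token_ids: list[int]  # response tokens, passed to ref_model.forward as `response`
    text: str             # decoded text, passed to the reward actor


@dataclass
class GenerateResult:  # hypothetical object returned by policy.generate.choose
    prompt_token_ids: list[int]  # request tokens, passed as `request`
    outputs: list[SampledCompletion] = field(default_factory=list)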
439374 )