# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Usage: python -m apps.trainer.main --config apps/grpo/qwen3_32b.yaml

import asyncio

import torch
import torchstore as ts
from forge.actors.trainer import RLTrainer
from forge.cli.config import parse
from forge.controller.launcher import JOB_NAME_KEY, LAUNCHER_KEY
from forge.controller.provisioner import init_provisioner, shutdown
from forge.observability.metric_actors import get_or_create_metric_logger
from forge.observability.metrics import record_metric, Reduce
from forge.observability.perf_tracker import Tracer
from forge.types import (
    Launcher,
    LauncherConfig,
    ProcessConfig,  # two import lines are elided in this excerpt; restored here
    ProvisionerConfig,  # because ProvisionerConfig is used by main() below
    ServiceConfig,
)
from omegaconf import DictConfig
from vllm.transformers_utils.tokenizer import get_tokenizer


def simple_grpo_loss(
    logits: torch.Tensor,
    response: torch.Tensor,
    ref_logprobs: torch.Tensor,
    advantages: torch.Tensor,
    padding_mask: torch.Tensor,
    beta: float = 0.1,
) -> torch.Tensor:
39+ """
40+ Simplified loss function for memory/CPU profiling purposes.
41+ Just performs basic tensor operations to simulate memory usage.
42+ """
    # Extract dimensions
    batch_size, response_len = response.shape
    vocab_size = logits.size(-1)
    full_seq_len = logits.size(1)

    # Extract only the response portion from logits
    # logits shape: [batch_size, request_len + response_len, vocab_size]
    # We want the last response_len tokens
    request_len = full_seq_len - response_len
    response_logits = logits[
        :, request_len:, :
    ]  # [batch_size, response_len, vocab_size]

    # Flatten logits and response for cross-entropy
    logits_flat = response_logits.reshape(-1, vocab_size)
    response_flat = response.reshape(-1)

    # Basic cross-entropy loss (simplified)
    loss = torch.nn.functional.cross_entropy(
        logits_flat, response_flat, reduction="none"
    ).view(batch_size, response_len)

    # Apply padding mask and reduce
    masked_loss = loss * padding_mask
    loss = masked_loss.sum() / padding_mask.sum().clamp(min=1.0)

    return loss
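
# Typical invocation shapes for simple_grpo_loss (hypothetical sizes):
#   logits:       [B, request_len + response_len, V], e.g. [4, 256, 32000]
#   response:     [B, response_len]  long
#   ref_logprobs: [B, response_len]  float (unused by this simplified loss)
#   advantages:   [B, 1]             float (unused)
#   padding_mask: [B, response_len]  bool
# The returned loss is a scalar tensor.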


def generate_random_batch(
    batch_size: int,
    request_len: int,
    response_len: int,
    vocab_size: int = 32000,
    device: str = "cuda",
    dp_size: int = 1,
):
    """
    Generate random input and target tensors matching the GRPO data format.
    Creates one batch per data parallel rank.
    """
    inputs = []
    targets = []

    # Create one batch for each data parallel rank
    for _ in range(dp_size):
        request = torch.randint(
            1, vocab_size, (batch_size, request_len), dtype=torch.long, device=device
        )
        response = torch.randint(
            1, vocab_size, (batch_size, response_len), dtype=torch.long, device=device
        )

        # Padding mask: True for real tokens, False for the ~10% of positions
        # randomly marked as padding
        padding_mask = torch.rand((batch_size, response_len), device=device) > 0.1

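        # Synthetic reference logprobs: strictly negative values, as real
        # log-probabilities would be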
        ref_logprobs = (
            -torch.abs(torch.randn((batch_size, response_len), device=device)) - 1.0
        )
        advantages = torch.randn((batch_size, 1), device=device)
        input_tokens = torch.cat([request, response], dim=1)
        inputs.append({"tokens": input_tokens})
        targets.append(
            {
                "response": response,
                "ref_logprobs": ref_logprobs,
                "advantages": advantages,
                "padding_mask": padding_mask,
            }
        )

    return inputs, targets
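
# For example, generate_random_batch(4, 128, 128, dp_size=2) returns two
# parallel lists of length 2: inputs[i] = {"tokens": LongTensor[4, 256]} and
# targets[i] holding the matching response/ref_logprobs/advantages/padding_mask
# tensors for data-parallel rank i.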


async def main(cfg: DictConfig):
    """
    Trainer simulation app for memory/CPU profiling and system usage analysis.

    This app initializes only the RLTrainer component and runs a training loop with
    synthetic random data to simulate real trainer system usage patterns. It is
    designed for:

    - Memory profiling of trainer infrastructure
    - CPU usage analysis during training steps
    - System resource monitoring (GPU memory, network, etc.)
    - Performance benchmarking of trainer components
    - Testing trainer stability under load

    The app uses the same configuration format as GRPO but bypasses policy generation,
    replay buffers, and reward computation, focusing purely on the trainer's
    computational and memory characteristics with realistic data shapes.
    """

    # Extract training parameters from existing GRPO config fields
    batch_size = cfg.get("batch_size", 4)
    request_len = cfg.get("max_req_tokens", 128)
    response_len = cfg.get("max_res_tokens", 128)
    max_training_steps = cfg.trainer.training.get("steps", 100)

    # Get vocab size from the actual model tokenizer
    model_name = cfg.get("model")
    print(f"Loading tokenizer for model: {model_name}")
    tokenizer = get_tokenizer(model_name)
    vocab_size = tokenizer.vocab_size
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
    print(f"Detected vocab size: {vocab_size}, pad token ID: {pad_id}")

    # Get data parallel size from replay buffer config (which matches trainer DP degree)
    dp_size = cfg.get("replay_buffer", {}).get("dp_size", None)
    if dp_size is None:
        # Fall back to the trainer parallelism config when replay_buffer.dp_size
        # is missing or explicitly null (the original default of 1 made this
        # fallback unreachable for missing keys)
        trainer_dp_degree = cfg.trainer.parallelism.get("data_parallel_shard_degree", 1)
        dp_size = trainer_dp_degree if trainer_dp_degree != -1 else 1
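    # For example, a config like this (hypothetical values) yields dp_size=2:
    #   replay_buffer:
    #     dp_size: 2
    #   trainer:
    #     parallelism:
    #       data_parallel_shard_degree: 2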

    await init_provisioner(
        ProvisionerConfig(
            launcher_config=LauncherConfig(
                # The field assignments here were elided in this excerpt; the
                # lines below are a plausible reconstruction from the imported
                # names (assumption), in the style of other forge apps
                launcher=Launcher(cfg.get(LAUNCHER_KEY, Launcher.SLURM.value)),
                job_name=cfg.get(JOB_NAME_KEY, None),
                services={k: ServiceConfig(**v) for k, v in cfg.services.items()},
                actors={k: ProcessConfig(**v) for k, v in cfg.actors.items()},
            )
        )
    )

    metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}})
    mlogger = await get_or_create_metric_logger()
    await mlogger.init_backends.call_one(metric_logging_cfg)

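    # Initialize torchstore, which the trainer's push_weights path presumably
    # uses to publish updated weights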
    await ts.initialize(strategy=ts.ControllerStorageVolumes())
    # Initialize trainer only
    print("Initializing trainer...")
    trainer = await RLTrainer.options(**cfg.actors.trainer).as_actor(
        **cfg.trainer, loss=simple_grpo_loss
    )
    print("Trainer initialized successfully!")
    print("Training configuration:")
    print(f"  - Batch size: {batch_size}")
    print(f"  - Request length: {request_len}")
    print(f"  - Response length: {response_len}")
    print(f"  - Vocab size: {vocab_size}")
    print(f"  - Data parallel size: {dp_size}")
    print(f"  - Max training steps: {max_training_steps}")

    async def continuous_training():
        training_step = 0

        print("Starting training loop with random data...")
        while training_step < max_training_steps:
            t = Tracer("trainer/continuous_training")
            t.start()

            inputs, targets = generate_random_batch(
                batch_size=batch_size,
                request_len=request_len,
                response_len=response_len,
                vocab_size=vocab_size,
                dp_size=dp_size,
            )
            t.step("generate_random_data")

            # Perform training step
            await trainer.train_step.call(inputs, targets)
            training_step += 1
            t.step("train_step")

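            # Push updated weights each step to exercise the weight-publishing
            # path (torchstore was initialized above)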
            await trainer.push_weights.call(training_step)
            t.step("push_weights")
            t.stop()

            # Flush metrics
            await mlogger.flush.call_one(training_step)

            print(f"Completed training step {training_step}/{max_training_steps}")

            # Sleep between steps to avoid overwhelming the system
            await asyncio.sleep(1.0)

    try:
        await continuous_training()
    except KeyboardInterrupt:
        print("Training interrupted by user")
    finally:
        print("Shutting down trainer...")
        await RLTrainer.shutdown(trainer)
        await mlogger.shutdown.call_one()
        await shutdown()
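

# The excerpt ends without a module entry point. The guard below is a sketch
# (an assumption) mirroring the pattern other forge apps use, wiring the
# imported `parse` decorator to `main`; it matches the usage line at the top
# of this file.
if __name__ == "__main__":

    @parse
    def _main(cfg):
        asyncio.run(main(cfg))

    _main()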