Commit d7d89d5

update

1 parent 8d3afda commit d7d89d5
1 file changed: +173 -17 lines changed

apps/trainer/main.py

Lines changed: 173 additions & 17 deletions
@@ -4,18 +4,19 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-# Usage: python -m apps.trainer.main --config apps/trainer/trainer_config.yaml
+# Usage: python -m apps.trainer.main --config apps/grpo/qwen3_32b.yaml
 
 import asyncio
 
 import torch
-import torch.nn.functional as F
-
+import torchstore as ts
 from forge.actors.trainer import RLTrainer
 from forge.cli.config import parse
 from forge.controller.launcher import JOB_NAME_KEY, LAUNCHER_KEY
 from forge.controller.provisioner import init_provisioner, shutdown
 from forge.observability.metric_actors import get_or_create_metric_logger
+from forge.observability.metrics import record_metric, Reduce
+from forge.observability.perf_tracker import Tracer
 from forge.types import (
     Launcher,
     LauncherConfig,
@@ -24,16 +25,135 @@
     ServiceConfig,
 )
 from omegaconf import DictConfig
+from vllm.transformers_utils.tokenizer import get_tokenizer
+
+
+def simple_grpo_loss(
+    logits: torch.Tensor,
+    response: torch.Tensor,
+    ref_logprobs: torch.Tensor,
+    advantages: torch.Tensor,
+    padding_mask: torch.Tensor,
+    beta: float = 0.1,
+) -> torch.Tensor:
+    """
+    Simplified loss function for memory/CPU profiling purposes.
+    Just performs basic tensor operations to simulate memory usage.
+    """
+    # Extract dimensions
+    batch_size, response_len = response.shape
+    vocab_size = logits.size(-1)
+    full_seq_len = logits.size(1)
+
+    # Extract only the response portion from logits
+    # logits shape: [batch_size, request_len + response_len, vocab_size]
+    # We want the last response_len tokens
+    request_len = full_seq_len - response_len
+    response_logits = logits[
+        :, request_len:, :
+    ]  # [batch_size, response_len, vocab_size]
+
+    # Flatten logits and response for cross-entropy
+    logits_flat = response_logits.reshape(-1, vocab_size)
+    response_flat = response.reshape(-1)
+
+    # Basic cross-entropy loss (simplified)
+    loss = torch.nn.functional.cross_entropy(
+        logits_flat, response_flat, reduction="none"
+    ).view(batch_size, response_len)
+
+    # Apply padding mask and reduce
+    masked_loss = loss * padding_mask
+    loss = masked_loss.sum() / padding_mask.sum().clamp(min=1.0)
+
+    return loss
+
+
+def generate_random_batch(
+    batch_size: int,
+    request_len: int,
+    response_len: int,
+    vocab_size: int = 32000,
+    device: str = "cuda",
+    dp_size: int = 1,
+):
+    """
+    Generate random input and target tensors matching GRPO data format.
+    Creates one batch per data parallel rank.
+    """
+    inputs = []
+    targets = []
+
+    # Create one batch for each data parallel rank
+    for _ in range(dp_size):
+        request = torch.randint(
+            1, vocab_size, (batch_size, request_len), dtype=torch.long, device=device
+        )
+        response = torch.randint(
+            1, vocab_size, (batch_size, response_len), dtype=torch.long, device=device
+        )
+
+        # Create padding mask (randomly mask some tokens as padding)
+        padding_mask = torch.rand((batch_size, response_len), device=device) > 0.1
 
+        ref_logprobs = (
+            -torch.abs(torch.randn((batch_size, response_len), device=device)) - 1.0
+        )
+        advantages = torch.randn((batch_size, 1), device=device)
+        input_tokens = torch.cat([request, response], dim=1)
+        inputs.append({"tokens": input_tokens})
+        targets.append(
+            {
+                "response": response,
+                "ref_logprobs": ref_logprobs,
+                "advantages": advantages,
+                "padding_mask": padding_mask,
+            }
+        )
 
-def placeholder_loss_function(logits, targets):
-    return F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
+    return inputs, targets
 
 
 async def main(cfg: DictConfig):
-    """Main function that only initializes the trainer."""
+    """
+    Trainer simulation app for memory/CPU profiling and system usage analysis.
+
+    This app initializes only the RLTrainer component and runs a training loop with
+    synthetic random data to simulate real trainer system usage patterns. It is
+    designed for:
+
+    - Memory profiling of trainer infrastructure
+    - CPU usage analysis during training steps
+    - System resource monitoring (GPU memory, network, etc.)
+    - Performance benchmarking of trainer components
+    - Testing trainer stability under load
+
+    The app uses the same configuration format as GRPO but bypasses policy generation,
+    replay buffers, and reward computation, focusing purely on the trainer's
+    computational and memory characteristics with realistic data shapes.
+    """
+
+    # Extract training parameters from existing GRPO config fields
+    batch_size = cfg.get("batch_size", 4)
+    request_len = cfg.get("max_req_tokens", 128)
+    response_len = cfg.get("max_res_tokens", 128)
+    max_training_steps = cfg.trainer.training.get("steps", 100)
+
+    # Get vocab size from the actual model tokenizer
+    model_name = cfg.get("model")
+    print(f"Loading tokenizer for model: {model_name}")
+    tokenizer = get_tokenizer(model_name)
+    vocab_size = tokenizer.vocab_size
+    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
+    print(f"Detected vocab size: {vocab_size}, pad token ID: {pad_id}")
+
+    # Get data parallel size from replay buffer config (which matches trainer DP degree)
+    dp_size = cfg.get("replay_buffer", {}).get("dp_size", 1)
+    if dp_size is None:
+        # Fall back to trainer config if replay_buffer.dp_size is not set
+        trainer_dp_degree = cfg.trainer.parallelism.get("data_parallel_shard_degree", 1)
+        dp_size = trainer_dp_degree if trainer_dp_degree != -1 else 1
 
-    # Initialize provisioner
     await init_provisioner(
         ProvisionerConfig(
             launcher_config=LauncherConfig(
@@ -45,29 +165,65 @@ async def main(cfg: DictConfig):
         )
     )
 
-    # Initialize metric logging
     metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}})
     mlogger = await get_or_create_metric_logger()
     await mlogger.init_backends.call_one(metric_logging_cfg)
 
+    await ts.initialize(strategy=ts.ControllerStorageVolumes())
     # Initialize trainer only
     print("Initializing trainer...")
     trainer = await RLTrainer.options(**cfg.actors.trainer).as_actor(
-        **cfg.trainer, loss=placeholder_loss_function
+        **cfg.trainer, loss=simple_grpo_loss
     )
-
     print("Trainer initialized successfully!")
-    print(f"Trainer configuration: {cfg.trainer}")
+    print("Training configuration:")
+    print(f" - Batch size: {batch_size}")
+    print(f" - Request length: {request_len}")
+    print(f" - Response length: {response_len}")
+    print(f" - Vocab size: {vocab_size}")
+    print(f" - Data parallel size: {dp_size}")
+    print(f" - Max training steps: {max_training_steps}")
+
+    async def continuous_training():
+        training_step = 0
+
+        print("Starting training loop with random data...")
+        while training_step < max_training_steps:
+            t = Tracer("trainer/continuous_training")
+            t.start()
+
+            inputs, targets = generate_random_batch(
+                batch_size=batch_size,
+                request_len=request_len,
+                response_len=response_len,
+                vocab_size=vocab_size,
+                dp_size=dp_size,
+            )
+            t.step("generate_random_data")
+
+            # Perform training step
+            await trainer.train_step.call(inputs, targets)
+            training_step += 1
+            t.step("train_step")
+
+            await trainer.push_weights.call(training_step)
+            t.step("push_weights")
+            t.stop()
+
+            # Flush metrics
+            await mlogger.flush.call_one(training_step)
+
+            print(f"Completed training step {training_step}/{max_training_steps}")
+
+            # Sleep between steps to avoid overwhelming the system
+            await asyncio.sleep(1.0)
 
-    # Keep the trainer running for demonstration
-    # In a real scenario, you might want to expose endpoints or do other work here
     try:
-        print("Trainer is running. Press Ctrl+C to shutdown...")
-        while True:
-            await asyncio.sleep(1)
+        await continuous_training()
     except KeyboardInterrupt:
-        print("Shutting down trainer...")
+        print("Training interrupted by user")
     finally:
+        print("Shutting down trainer...")
         await RLTrainer.shutdown(trainer)
         await mlogger.shutdown.call_one()
         await shutdown()
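
A minimal sketch (not part of this commit) that smoke-tests simple_grpo_loss standalone on CPU with random tensors. The shapes mirror generate_random_batch above; ref_logprobs, advantages, and beta are accepted but unused by this simplified loss, and the import path assumes apps.trainer.main is importable from the repository root:

    import torch

    from apps.trainer.main import simple_grpo_loss

    batch_size, request_len, response_len, vocab_size = 2, 8, 16, 128
    # Logits cover the full request + response sequence, as in the trainer
    logits = torch.randn(batch_size, request_len + response_len, vocab_size)
    response = torch.randint(1, vocab_size, (batch_size, response_len))
    ref_logprobs = -torch.abs(torch.randn(batch_size, response_len)) - 1.0
    advantages = torch.randn(batch_size, 1)
    # Float mask keeps the masked multiply and clamp well-defined
    padding_mask = (torch.rand(batch_size, response_len) > 0.1).float()

    loss = simple_grpo_loss(logits, response, ref_logprobs, advantages, padding_mask)
    print(loss.item())  # roughly log(vocab_size) ~= 4.85 for random logits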

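Similarly, a quick shape check (again a sketch, under the same import assumption) for generate_random_batch, which returns one {"tokens": ...} input dict and one matching target dict per data parallel rank:

    from apps.trainer.main import generate_random_batch

    inputs, targets = generate_random_batch(
        batch_size=4, request_len=128, response_len=128,
        vocab_size=32000, device="cpu", dp_size=2,
    )
    assert len(inputs) == len(targets) == 2  # one batch per DP rank
    assert inputs[0]["tokens"].shape == (4, 128 + 128)  # request + response tokens
    assert targets[0]["response"].shape == (4, 128)

The app itself is launched per the updated usage comment in the diff: python -m apps.trainer.main --config apps/grpo/qwen3_32b.yaml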
0 commit comments