[Debugging] Create a standalone trainer app #313
Merged
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Usage: python -m apps.trainer.main --config apps/grpo/qwen3_32b.yaml

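# The app reuses a GRPO-style config. The scalar fields it reads map onto YAML
# roughly as below (an illustrative sketch inferred from the cfg lookups in
# main(), not a shipped config file; the model name is a placeholder):
#
#   model: Qwen/Qwen3-32B
#   batch_size: 4
#   max_req_tokens: 128
#   max_res_tokens: 128
#   replay_buffer:
#     dp_size: 1
#   trainer:
#     training:
#       steps: 100
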
import asyncio

import torch
import torchstore as ts
from forge.actors.trainer import RLTrainer
from forge.cli.config import parse
from forge.controller.launcher import JOB_NAME_KEY, LAUNCHER_KEY
from forge.controller.provisioner import init_provisioner, shutdown
from forge.observability.metric_actors import get_or_create_metric_logger
from forge.observability.perf_tracker import Tracer
from forge.types import (
    Launcher,
    LauncherConfig,
    ProcessConfig,
    ProvisionerConfig,
    ServiceConfig,
)
from omegaconf import DictConfig
from vllm.transformers_utils.tokenizer import get_tokenizer


def simple_grpo_loss(
    logits: torch.Tensor,
    response: torch.Tensor,
    ref_logprobs: torch.Tensor,
    advantages: torch.Tensor,
    padding_mask: torch.Tensor,
    beta: float = 0.1,
) -> torch.Tensor:
    """
    Simplified loss function for memory/CPU profiling purposes.

    Performs only basic tensor operations to simulate the memory usage of a
    real GRPO loss. Note that ref_logprobs, advantages, and beta are accepted
    for signature compatibility with the real loss but are unused here.
    """
    # Extract dimensions
    batch_size, response_len = response.shape
    vocab_size = logits.size(-1)
    full_seq_len = logits.size(1)

    # Extract only the response portion from logits.
    # logits shape: [batch_size, request_len + response_len, vocab_size];
    # we want the last response_len tokens.
    request_len = full_seq_len - response_len
    response_logits = logits[
        :, request_len:, :
    ]  # [batch_size, response_len, vocab_size]

    # Flatten logits and response for cross-entropy
    logits_flat = response_logits.reshape(-1, vocab_size)
    response_flat = response.reshape(-1)

    # Basic cross-entropy loss (simplified)
    loss = torch.nn.functional.cross_entropy(
        logits_flat, response_flat, reduction="none"
    ).view(batch_size, response_len)

    # Apply the padding mask and reduce to a scalar, guarding against
    # all-padding batches
    masked_loss = loss * padding_mask
    loss = masked_loss.sum() / padding_mask.sum().clamp(min=1.0)

    return loss


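# Illustrative shape check (hypothetical values, not part of the app): because
# the simplified loss ignores ref_logprobs and advantages, a quick smoke test
# can pass None for both and expect a scalar back:
#
#   logits = torch.randn(2, 10, 8)           # batch=2, req_len=6 + res_len=4, vocab=8
#   response = torch.randint(0, 8, (2, 4))
#   mask = torch.ones(2, 4)
#   loss = simple_grpo_loss(logits, response, None, None, mask)
#   assert loss.ndim == 0                    # masked mean over per-token CE
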
def generate_random_batch(
    batch_size: int,
    request_len: int,
    response_len: int,
    vocab_size: int = 32000,
    device: str = "cuda",
    dp_size: int = 1,
):
    """
    Generate random input and target tensors matching the GRPO data format,
    creating one batch per data parallel rank.
    """
    inputs = []
    targets = []

    # Create one batch for each data parallel rank
    for _ in range(dp_size):
        request = torch.randint(
            1, vocab_size, (batch_size, request_len), dtype=torch.long, device=device
        )
        response = torch.randint(
            1, vocab_size, (batch_size, response_len), dtype=torch.long, device=device
        )

        # Create a padding mask (randomly mask ~10% of tokens as padding)
        padding_mask = torch.rand((batch_size, response_len), device=device) > 0.1

        # Fake reference logprobs (negative, as real logprobs would be)
        ref_logprobs = (
            -torch.abs(torch.randn((batch_size, response_len), device=device)) - 1.0
        )
        advantages = torch.randn((batch_size, 1), device=device)
        input_tokens = torch.cat([request, response], dim=1)
        inputs.append({"tokens": input_tokens})
        targets.append(
            {
                "response": response,
                "ref_logprobs": ref_logprobs,
                "advantages": advantages,
                "padding_mask": padding_mask,
            }
        )

    return inputs, targets


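# Illustrative use (hypothetical sizes): on a CUDA host,
#
#   inputs, targets = generate_random_batch(4, 128, 128, vocab_size=32000, dp_size=2)
#
# returns two parallel lists, one entry per DP rank; inputs[0]["tokens"] has
# shape [4, 256] (request + response) and targets[0]["response"] has shape [4, 128].
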
async def main(cfg: DictConfig):
    """
    Trainer simulation app for memory/CPU profiling and system usage analysis.

    This app initializes only the RLTrainer component and runs a training loop with
    synthetic random data to simulate real trainer system usage patterns. It is
    designed for:

    - Memory profiling of trainer infrastructure
    - CPU usage analysis during training steps
    - System resource monitoring (GPU memory, network, etc.)
    - Performance benchmarking of trainer components
    - Testing trainer stability under load

    The app uses the same configuration format as GRPO but bypasses policy generation,
    replay buffers, and reward computation, focusing purely on the trainer's
    computational and memory characteristics with realistic data shapes.
    """
    # Extract training parameters from existing GRPO config fields
    batch_size = cfg.get("batch_size", 4)
    request_len = cfg.get("max_req_tokens", 128)
    response_len = cfg.get("max_res_tokens", 128)
    max_training_steps = cfg.trainer.training.get("steps", 100)

    # Get the vocab size and pad token from the actual model tokenizer
    model_name = cfg.get("model")
    print(f"Loading tokenizer for model: {model_name}")
    tokenizer = get_tokenizer(model_name)
    vocab_size = tokenizer.vocab_size
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
    print(f"Detected vocab size: {vocab_size}, pad token ID: {pad_id}")

    # Get the data parallel size from the replay buffer config (which matches
    # the trainer DP degree)
    dp_size = cfg.get("replay_buffer", {}).get("dp_size", 1)
    if dp_size is None:
        # Fall back to the trainer parallelism config when replay_buffer.dp_size
        # is present but explicitly null
        trainer_dp_degree = cfg.trainer.parallelism.get("data_parallel_shard_degree", 1)
        dp_size = trainer_dp_degree if trainer_dp_degree != -1 else 1

    await init_provisioner(
        ProvisionerConfig(
            launcher_config=LauncherConfig(
                launcher=Launcher(cfg.get(LAUNCHER_KEY, Launcher.SLURM.value)),
                job_name=cfg.get(JOB_NAME_KEY, None),
                services={k: ServiceConfig(**v) for k, v in cfg.services.items()},
                actors={k: ProcessConfig(**v) for k, v in cfg.actors.items()},
            )
        )
    )

    metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}})
    mlogger = await get_or_create_metric_logger()
    await mlogger.init_backends.call_one(metric_logging_cfg)

    await ts.initialize(strategy=ts.ControllerStorageVolumes())

    # Initialize the trainer only; the rest of the GRPO stack is skipped
    print("Initializing trainer...")
    trainer = await RLTrainer.options(**cfg.actors.trainer).as_actor(
        **cfg.trainer, loss=simple_grpo_loss
    )
    print("Trainer initialized successfully with the following configs:")
    print(f" - Batch size: {batch_size}")
    print(f" - Request length: {request_len}")
    print(f" - Response length: {response_len}")
    print(f" - Vocab size: {vocab_size}")
    print(f" - Data parallel size: {dp_size}")
    print(f" - Max training steps: {max_training_steps}")

    async def continuous_training():
        training_step = 0

        print("Starting training loop with random data...")
        while training_step < max_training_steps:
            t = Tracer("trainer/continuous_training")
            t.start()

            inputs, targets = generate_random_batch(
                batch_size=batch_size,
                request_len=request_len,
                response_len=response_len,
                vocab_size=vocab_size,
                dp_size=dp_size,
            )
            t.step("generate_random_data")

            # Perform a training step on the synthetic batch
            await trainer.train_step.call(inputs, targets)
            training_step += 1
            t.step("train_step")

            # Push updated weights, as the real GRPO loop would
            await trainer.push_weights.call(training_step)
            t.step("push_weights")
            t.stop()

            # Flush metrics
            await mlogger.flush.call_one(training_step)

            print(f"Completed training step {training_step}/{max_training_steps}")

            # Sleep between steps to avoid overwhelming the system
            await asyncio.sleep(1.0)

    try:
        await continuous_training()
    except KeyboardInterrupt:
        print("Training interrupted by user")
    finally:
        print("Shutting down trainer...")
        await RLTrainer.shutdown(trainer)
        await mlogger.shutdown.call_one()
        await shutdown()
        print("Trainer shutdown complete.")


if __name__ == "__main__":

    @parse
    def _main(cfg):
        asyncio.run(main(cfg))

    _main()  # @parse grabs the cfg from the CLI
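# Since the whole point of the app is profiling, it can be run unchanged under
# an external profiler. Illustrative invocation (assumes Nsight Systems is
# installed; any standard CPU/GPU profiler can wrap the entry point similarly):
#
#   nsys profile -o trainer_profile python -m apps.trainer.main --config apps/grpo/qwen3_32b.yaml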
nit: can we rename this to rl_trainer? The reason is that we already have sft and sft_v2, so a separate trainer may confuse things even further lol