From 8d3afdac1685c93b390cf2f8753e1fea24c74fbe Mon Sep 17 00:00:00 2001 From: "Jiyue (Jennifer) Wang" Date: Fri, 3 Oct 2025 21:55:44 -0400 Subject: [PATCH 1/4] initial commit --- apps/trainer/main.py | 83 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 83 insertions(+) create mode 100644 apps/trainer/main.py diff --git a/apps/trainer/main.py b/apps/trainer/main.py new file mode 100644 index 000000000..d27079d61 --- /dev/null +++ b/apps/trainer/main.py @@ -0,0 +1,83 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Usage: python -m apps.trainer.main --config apps/trainer/trainer_config.yaml + +import asyncio + +import torch +import torch.nn.functional as F + +from forge.actors.trainer import RLTrainer +from forge.cli.config import parse +from forge.controller.launcher import JOB_NAME_KEY, LAUNCHER_KEY +from forge.controller.provisioner import init_provisioner, shutdown +from forge.observability.metric_actors import get_or_create_metric_logger +from forge.types import ( + Launcher, + LauncherConfig, + ProcessConfig, + ProvisionerConfig, + ServiceConfig, +) +from omegaconf import DictConfig + + +def placeholder_loss_function(logits, targets): + return F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1)) + + +async def main(cfg: DictConfig): + """Main function that only initializes the trainer.""" + + # Initialize provisioner + await init_provisioner( + ProvisionerConfig( + launcher_config=LauncherConfig( + launcher=Launcher(cfg.get(LAUNCHER_KEY, Launcher.SLURM.value)), + job_name=cfg.get(JOB_NAME_KEY, None), + services={k: ServiceConfig(**v) for k, v in cfg.services.items()}, + actors={k: ProcessConfig(**v) for k, v in cfg.actors.items()}, + ) + ) + ) + + # Initialize metric logging + metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}}) + mlogger = await get_or_create_metric_logger() + await mlogger.init_backends.call_one(metric_logging_cfg) + + # Initialize trainer only + print("Initializing trainer...") + trainer = await RLTrainer.options(**cfg.actors.trainer).as_actor( + **cfg.trainer, loss=placeholder_loss_function + ) + + print("Trainer initialized successfully!") + print(f"Trainer configuration: {cfg.trainer}") + + # Keep the trainer running for demonstration + # In a real scenario, you might want to expose endpoints or do other work here + try: + print("Trainer is running. Press Ctrl+C to shutdown...") + while True: + await asyncio.sleep(1) + except KeyboardInterrupt: + print("Shutting down trainer...") + finally: + await RLTrainer.shutdown(trainer) + await mlogger.shutdown.call_one() + await shutdown() + print("Trainer shutdown complete.") + + +if __name__ == "__main__": + + @parse + def _main(cfg): + asyncio.run(main(cfg)) + + _main() # @parse grabs the cfg from CLI From d7d89d56f74c244a023bf648be555654bcc720c7 Mon Sep 17 00:00:00 2001 From: "Jiyue (Jennifer) Wang" Date: Fri, 3 Oct 2025 22:39:00 -0400 Subject: [PATCH 2/4] update --- apps/trainer/main.py | 190 +++++++++++++++++++++++++++++++++++++++---- 1 file changed, 173 insertions(+), 17 deletions(-) diff --git a/apps/trainer/main.py b/apps/trainer/main.py index d27079d61..a05d87f01 100644 --- a/apps/trainer/main.py +++ b/apps/trainer/main.py @@ -4,18 +4,19 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# Usage: python -m apps.trainer.main --config apps/trainer/trainer_config.yaml +# Usage: python -m apps.trainer.main --config apps/grpo/qwen3_32b.yaml import asyncio import torch -import torch.nn.functional as F - +import torchstore as ts from forge.actors.trainer import RLTrainer from forge.cli.config import parse from forge.controller.launcher import JOB_NAME_KEY, LAUNCHER_KEY from forge.controller.provisioner import init_provisioner, shutdown from forge.observability.metric_actors import get_or_create_metric_logger +from forge.observability.metrics import record_metric, Reduce +from forge.observability.perf_tracker import Tracer from forge.types import ( Launcher, LauncherConfig, @@ -24,16 +25,135 @@ ServiceConfig, ) from omegaconf import DictConfig +from vllm.transformers_utils.tokenizer import get_tokenizer + + +def simple_grpo_loss( + logits: torch.Tensor, + response: torch.Tensor, + ref_logprobs: torch.Tensor, + advantages: torch.Tensor, + padding_mask: torch.Tensor, + beta: float = 0.1, +) -> torch.Tensor: + """ + Simplified loss function for memory/CPU profiling purposes. + Just performs basic tensor operations to simulate memory usage. + """ + # Extract dimensions + batch_size, response_len = response.shape + vocab_size = logits.size(-1) + full_seq_len = logits.size(1) + + # Extract only the response portion from logits + # logits shape: [batch_size, request_len + response_len, vocab_size] + # We want the last response_len tokens + request_len = full_seq_len - response_len + response_logits = logits[ + :, request_len:, : + ] # [batch_size, response_len, vocab_size] + + # Flatten logits and response for cross-entropy + logits_flat = response_logits.reshape(-1, vocab_size) + response_flat = response.reshape(-1) + + # Basic cross-entropy loss (simplified) + loss = torch.nn.functional.cross_entropy( + logits_flat, response_flat, reduction="none" + ).view(batch_size, response_len) + + # Apply padding mask and reduce + masked_loss = loss * padding_mask + loss = masked_loss.sum() / padding_mask.sum().clamp(min=1.0) + + return loss + + +def generate_random_batch( + batch_size: int, + request_len: int, + response_len: int, + vocab_size: int = 32000, + device: str = "cuda", + dp_size: int = 1, +): + """ + Generate random input and target tensors matching GRPO data format + Creates one batch per data parallel rank + """ + inputs = [] + targets = [] + + # Create one batch for each data parallel rank + for _ in range(dp_size): + request = torch.randint( + 1, vocab_size, (batch_size, request_len), dtype=torch.long, device=device + ) + response = torch.randint( + 1, vocab_size, (batch_size, response_len), dtype=torch.long, device=device + ) + + # Create padding mask (randomly mask some tokens as padding) + padding_mask = torch.rand((batch_size, response_len), device=device) > 0.1 + ref_logprobs = ( + -torch.abs(torch.randn((batch_size, response_len), device=device)) - 1.0 + ) + advantages = torch.randn((batch_size, 1), device=device) + input_tokens = torch.cat([request, response], dim=1) + inputs.append({"tokens": input_tokens}) + targets.append( + { + "response": response, + "ref_logprobs": ref_logprobs, + "advantages": advantages, + "padding_mask": padding_mask, + } + ) -def placeholder_loss_function(logits, targets): - return F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1)) + return inputs, targets async def main(cfg: DictConfig): - """Main function that only initializes the trainer.""" + """ + Trainer simulation app for memory/CPU profiling and system usage analysis. + + This app initializes only the RLTrainer component and runs a training loop with + synthetic random data to simulate real trainer system usage patterns. It is + designed for: + + - Memory profiling of trainer infrastructure + - CPU usage analysis during training steps + - System resource monitoring (GPU memory, network, etc.) + - Performance benchmarking of trainer components + - Testing trainer stability under load + + The app uses the same configuration format as GRPO but bypasses policy generation, + replay buffers, and reward computation, focusing purely on the trainer's + computational and memory characteristics with realistic data shapes. + """ + + # Extract training parameters from existing GRPO config fields + batch_size = cfg.get("batch_size", 4) + request_len = cfg.get("max_req_tokens", 128) + response_len = cfg.get("max_res_tokens", 128) + max_training_steps = cfg.trainer.training.get("steps", 100) + + # Get vocab size from the actual model tokenizer + model_name = cfg.get("model") + print(f"Loading tokenizer for model: {model_name}") + tokenizer = get_tokenizer(model_name) + vocab_size = tokenizer.vocab_size + pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0 + print(f"Detected vocab size: {vocab_size}, pad token ID: {pad_id}") + + # Get data parallel size from replay buffer config (which matches trainer DP degree) + dp_size = cfg.get("replay_buffer", {}).get("dp_size", 1) + if dp_size is None: + # Fallback to trainer config if replay_buffer.dp_size not set + trainer_dp_degree = cfg.trainer.parallelism.get("data_parallel_shard_degree", 1) + dp_size = trainer_dp_degree if trainer_dp_degree != -1 else 1 - # Initialize provisioner await init_provisioner( ProvisionerConfig( launcher_config=LauncherConfig( @@ -45,29 +165,65 @@ async def main(cfg: DictConfig): ) ) - # Initialize metric logging metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}}) mlogger = await get_or_create_metric_logger() await mlogger.init_backends.call_one(metric_logging_cfg) + await ts.initialize(strategy=ts.ControllerStorageVolumes()) # Initialize trainer only print("Initializing trainer...") trainer = await RLTrainer.options(**cfg.actors.trainer).as_actor( - **cfg.trainer, loss=placeholder_loss_function + **cfg.trainer, loss=simple_grpo_loss ) - print("Trainer initialized successfully!") - print(f"Trainer configuration: {cfg.trainer}") + print(f"Training configuration:") + print(f" - Batch size: {batch_size}") + print(f" - Request length: {request_len}") + print(f" - Response length: {response_len}") + print(f" - Vocab size: {vocab_size}") + print(f" - Data parallel size: {dp_size}") + print(f" - Max training steps: {max_training_steps}") + + async def continuous_training(): + training_step = 0 + + print("Starting training loop with random data...") + while training_step < max_training_steps: + t = Tracer("trainer/continuous_training") + t.start() + + inputs, targets = generate_random_batch( + batch_size=batch_size, + request_len=request_len, + response_len=response_len, + vocab_size=vocab_size, + dp_size=dp_size, + ) + t.step("generate_random_data") + + # Perform training step + await trainer.train_step.call(inputs, targets) + training_step += 1 + t.step("train_step") + + await trainer.push_weights.call(training_step) + t.step("push_weights") + t.stop() + + # Flush metrics + await mlogger.flush.call_one(training_step) + + print(f"Completed training step {training_step}/{max_training_steps}") + + # Sleep between steps to avoid overwhelming the system + await asyncio.sleep(1.0) - # Keep the trainer running for demonstration - # In a real scenario, you might want to expose endpoints or do other work here try: - print("Trainer is running. Press Ctrl+C to shutdown...") - while True: - await asyncio.sleep(1) + await continuous_training() except KeyboardInterrupt: - print("Shutting down trainer...") + print("Training interrupted by user") finally: + print("Shutting down trainer...") await RLTrainer.shutdown(trainer) await mlogger.shutdown.call_one() await shutdown() From 26d0cb38eb7c0dfaf785b998cb37e0cd4e934ce1 Mon Sep 17 00:00:00 2001 From: "Jiyue (Jennifer) Wang" Date: Fri, 3 Oct 2025 22:47:21 -0400 Subject: [PATCH 3/4] format --- apps/trainer/main.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/apps/trainer/main.py b/apps/trainer/main.py index a05d87f01..0f8036777 100644 --- a/apps/trainer/main.py +++ b/apps/trainer/main.py @@ -15,7 +15,6 @@ from forge.controller.launcher import JOB_NAME_KEY, LAUNCHER_KEY from forge.controller.provisioner import init_provisioner, shutdown from forge.observability.metric_actors import get_or_create_metric_logger -from forge.observability.metrics import record_metric, Reduce from forge.observability.perf_tracker import Tracer from forge.types import ( Launcher, @@ -175,8 +174,7 @@ async def main(cfg: DictConfig): trainer = await RLTrainer.options(**cfg.actors.trainer).as_actor( **cfg.trainer, loss=simple_grpo_loss ) - print("Trainer initialized successfully!") - print(f"Training configuration:") + print("Trainer initialized successfully with following configs!") print(f" - Batch size: {batch_size}") print(f" - Request length: {request_len}") print(f" - Response length: {response_len}") From 4f287ec311ce5139e6d3ee53b44decbb8dbf755e Mon Sep 17 00:00:00 2001 From: "Jiyue (Jennifer) Wang" Date: Mon, 6 Oct 2025 09:05:45 -0400 Subject: [PATCH 4/4] rename --- apps/{trainer => rl_trainer}/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) rename apps/{trainer => rl_trainer}/main.py (99%) diff --git a/apps/trainer/main.py b/apps/rl_trainer/main.py similarity index 99% rename from apps/trainer/main.py rename to apps/rl_trainer/main.py index 0f8036777..8473cc16d 100644 --- a/apps/trainer/main.py +++ b/apps/rl_trainer/main.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# Usage: python -m apps.trainer.main --config apps/grpo/qwen3_32b.yaml +# Usage: python -m apps.rl_trainer.main --config apps/grpo/qwen3_32b.yaml import asyncio