diff --git a/tests/sandbox/rl_trainer/main.py b/tests/sandbox/rl_trainer/main.py
index 1441bb9e3..6d71ddf69 100644
--- a/tests/sandbox/rl_trainer/main.py
+++ b/tests/sandbox/rl_trainer/main.py
@@ -40,17 +40,17 @@ def simple_grpo_loss(
     Just performs basic tensor operations to simulate memory usage.
     """
     # Extract dimensions
-    batch_size, response_len = response.shape
+    local_batch_size, response_len = response.shape
     vocab_size = logits.size(-1)
     full_seq_len = logits.size(1)
 
     # Extract only the response portion from logits
-    # logits shape: [batch_size, request_len + response_len, vocab_size]
+    # logits shape: [local_batch_size, request_len + response_len, vocab_size]
     # We want the last response_len tokens
     request_len = full_seq_len - response_len
     response_logits = logits[
         :, request_len:, :
-    ]  # [batch_size, response_len, vocab_size]
+    ]  # [local_batch_size, response_len, vocab_size]
 
     # Flatten logits and response for cross-entropy
     logits_flat = response_logits.reshape(-1, vocab_size)
@@ -59,7 +59,7 @@ def simple_grpo_loss(
     # Basic cross-entropy loss (simplified)
     loss = torch.nn.functional.cross_entropy(
         logits_flat, response_flat, reduction="none"
-    ).view(batch_size, response_len)
+    ).view(local_batch_size, response_len)
 
     # Apply padding mask and reduce
     masked_loss = loss * padding_mask
@@ -69,7 +69,7 @@
 
 
 def generate_random_batch(
-    batch_size: int,
+    local_batch_size: int,
     request_len: int,
     response_len: int,
     vocab_size: int = 32000,
@@ -86,19 +86,28 @@ def generate_random_batch(
     # Create one batch for each data parallel rank
     for _ in range(dp_size):
         request = torch.randint(
-            1, vocab_size, (batch_size, request_len), dtype=torch.long, device=device
+            1,
+            vocab_size,
+            (local_batch_size, request_len),
+            dtype=torch.long,
+            device=device,
         )
         response = torch.randint(
-            1, vocab_size, (batch_size, response_len), dtype=torch.long, device=device
+            1,
+            vocab_size,
+            (local_batch_size, response_len),
+            dtype=torch.long,
+            device=device,
         )
 
         # Create padding mask (randomly mask some tokens as padding)
-        padding_mask = torch.rand((batch_size, response_len), device=device) > 0.1
+        padding_mask = torch.rand((local_batch_size, response_len), device=device) > 0.1
         ref_logprobs = (
-            -torch.abs(torch.randn((batch_size, response_len), device=device)) - 1.0
+            -torch.abs(torch.randn((local_batch_size, response_len), device=device))
+            - 1.0
         )
-        advantages = torch.randn((batch_size, 1), device=device)
+        advantages = torch.randn((local_batch_size, 1), device=device)
 
         input_tokens = torch.cat([request, response], dim=1)
         inputs.append({"tokens": input_tokens})
         targets.append(
@@ -133,7 +142,9 @@ async def main(cfg: DictConfig):
 
     """
    # Extract training parameters from existing GRPO config fields
-    batch_size = cfg.get("batch_size", 4)
+    local_batch_size = cfg.get("local_batch_size", None)
+    assert local_batch_size is not None, "local_batch_size must be specified"
+
     request_len = cfg.get("max_req_tokens", 128)
     response_len = cfg.get("max_res_tokens", 128)
     max_training_steps = cfg.trainer.training.get("steps", 100)
@@ -156,7 +167,7 @@ async def main(cfg: DictConfig):
     await init_provisioner(
         ProvisionerConfig(
             launcher_config=LauncherConfig(
-                launcher=Launcher(cfg.get(LAUNCHER_KEY, Launcher.SLURM.value)),
+                launcher=cfg.get(LAUNCHER_KEY, Launcher.SLURM.value),
                 job_name=cfg.get(JOB_NAME_KEY, None),
                 services={k: ServiceConfig(**v) for k, v in cfg.services.items()},
                 actors={k: ProcessConfig(**v) for k, v in cfg.actors.items()},
@@ -175,7 +186,7 @@ async def main(cfg: DictConfig):
         **cfg.trainer, loss=simple_grpo_loss
     )
     print("Trainer initialized successfully with following configs!")
-    print(f"  - Batch size: {batch_size}")
+    print(f"  - Local batch size: {local_batch_size}")
     print(f"  - Request length: {request_len}")
     print(f"  - Response length: {response_len}")
     print(f"  - Vocab size: {vocab_size}")
@@ -191,7 +202,7 @@ async def continuous_training():
         t.start()
 
         inputs, targets = generate_random_batch(
-            batch_size=batch_size,
+            local_batch_size=local_batch_size,
             request_len=request_len,
             response_len=response_len,
             vocab_size=vocab_size,
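
Note (not part of the patch): a minimal standalone sketch of the shape logic that simple_grpo_loss exercises after the batch_size -> local_batch_size rename. All dimensions and tensor values below are arbitrary illustration choices, not code from the repository.

import torch

# Illustrative dimensions (assumed for this sketch only)
local_batch_size, request_len, response_len, vocab_size = 2, 5, 3, 11

# Fake model outputs and targets with the shapes the loss expects:
# logits cover request + response positions; only the trailing
# response_len positions contribute to the loss.
logits = torch.randn(local_batch_size, request_len + response_len, vocab_size)
response = torch.randint(1, vocab_size, (local_batch_size, response_len))
padding_mask = torch.rand(local_batch_size, response_len) > 0.1

# Slice off the response portion, then compute per-token cross-entropy.
response_logits = logits[:, request_len:, :]
loss = torch.nn.functional.cross_entropy(
    response_logits.reshape(-1, vocab_size),
    response.reshape(-1),
    reduction="none",
).view(local_batch_size, response_len)

# Masking keeps the [local_batch_size, response_len] layout intact.
masked_loss = loss * padding_mask
print(masked_loss.shape)  # torch.Size([2, 3])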