diff --git a/tests/sandbox/weight_sync/main.py b/tests/sandbox/weight_sync/main.py
new file mode 100644
index 000000000..dfdc58f0a
--- /dev/null
+++ b/tests/sandbox/weight_sync/main.py
@@ -0,0 +1,205 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Weight Sync Sandbox
+
+A minimal test environment focused exclusively on testing the weight synchronization
+mechanism between RLTrainer and Generator.
+
+Usage:
+    python -m tests.sandbox.weight_sync.main --config tests/sandbox/weight_sync/qwen3_1_7b.yaml
+"""
+
+import asyncio
+import time
+
+import torch
+import torchstore as ts
+from forge.actors._torchstore_utils import rdma_enabled
+from forge.actors.generator import Generator
+from forge.actors.trainer import RLTrainer
+from forge.controller.provisioner import init_provisioner, shutdown
+from forge.observability.metric_actors import get_or_create_metric_logger
+from forge.types import LauncherConfig, ProvisionerConfig
+from forge.util.config import parse
+from omegaconf import DictConfig
+from vllm.transformers_utils.tokenizer import get_tokenizer
+
+
+def generate_random_batch(
+    local_batch_size: int,
+    request_len: int,
+    response_len: int,
+    vocab_size: int = 32000,
+    device: str = "cuda",
+    dp_size: int = 1,
+):
+    """
+    Generate random input and target tensors for a single training step.
+    Creates one batch per data parallel rank.
+    """
+    inputs = []
+    targets = []
+
+    # Create one batch for each data parallel rank
+    for _ in range(dp_size):
+        request = torch.randint(
+            1,
+            vocab_size,
+            (local_batch_size, request_len),
+            dtype=torch.long,
+            device=device,
+        )
+        response = torch.randint(
+            1,
+            vocab_size,
+            (local_batch_size, response_len),
+            dtype=torch.long,
+            device=device,
+        )
+
+        # Create padding mask
+        padding_mask = torch.rand((local_batch_size, response_len), device=device) > 0.1
+
+        ref_logprobs = (
+            -torch.abs(torch.randn((local_batch_size, response_len), device=device))
+            - 1.0
+        )
+        advantages = torch.randn((local_batch_size, 1), device=device)
+        input_tokens = torch.cat([request, response], dim=1)
+        inputs.append({"tokens": input_tokens})
+        targets.append(
+            {
+                "response": response,
+                "ref_logprobs": ref_logprobs,
+                "advantages": advantages,
+                "padding_mask": padding_mask,
+            }
+        )
+
+    return inputs, targets
+
+
+async def main(cfg: DictConfig):
+    local_batch_size = cfg.get("local_batch_size", None)
+    assert local_batch_size is not None, "local_batch_size must be specified"
+
+    request_len = cfg.get("max_req_tokens", 64)
+    response_len = cfg.get("max_res_tokens", 64)
+    model_name = cfg.get("model")
+
+    print(f"Loading tokenizer for model: {model_name}")
+    tokenizer = get_tokenizer(model_name)
+    vocab_size = tokenizer.vocab_size
+    print(f"Detected vocab size: {vocab_size}")
+
+    trainer_dp_degree = cfg.trainer.parallelism.get("data_parallel_shard_degree", 1)
+    dp_size = trainer_dp_degree if trainer_dp_degree != -1 else 1
+
+    # ---- Global setups ---- #
+    provisioner = None
+    if cfg.get("provisioner", None) is not None:
+        provisioner = await init_provisioner(
+            ProvisionerConfig(launcher_config=LauncherConfig(**cfg.provisioner))
+        )
+    else:
+        provisioner = await init_provisioner()
+
+    metric_logging_cfg = cfg.get("metric_logging", {})
+    mlogger = await get_or_create_metric_logger(process_name="Controller")
+    await mlogger.init_backends.call_one(metric_logging_cfg)
+
+    # Initialize torchstore
+    await ts.initialize(strategy=ts.ControllerStorageVolumes())
+
+    print("=" * 80)
+    print(f"Model: {model_name}")
+    print(f"Local batch size: {local_batch_size}")
+    print(
+        f"Sequence length: {request_len + response_len} ({request_len} + {response_len})"
+    )
+    print(f"Data parallel size: {dp_size}")
+    print(f"Is RDMA available? {rdma_enabled()}")
+    print("=" * 80 + "\n")
+
+    # Initialize trainer and generator
+    print("Initializing trainer and generator...")
+    init_start = time.time()
+
+    trainer, policy = await asyncio.gather(
+        RLTrainer.options(**cfg.actors.trainer).as_actor(
+            **cfg.trainer,
+            loss=lambda *args, **kwargs: torch.tensor(
+                1.0, requires_grad=True, device="cuda"
+            ),
+        ),
+        Generator.options(**cfg.actors.policy).as_actor(**cfg.policy),
+    )
+
+    init_time = time.time() - init_start
+    print(f"Finished initialization in ({init_time:.2f}s)")
+
+    # Run one training step to create weight delta
+    print("Running single training step...")
+    step_start = time.time()
+
+    inputs, targets = generate_random_batch(
+        local_batch_size=local_batch_size,
+        request_len=request_len,
+        response_len=response_len,
+        vocab_size=vocab_size,
+        dp_size=dp_size,
+    )
+
+    await trainer.train_step.call(inputs, targets)
+    step_time = time.time() - step_start
+    print(f"Finished train step in ({step_time:.2f}s)\n")
+
+    # Test push_weights
+    print("Pushing weights from trainer to store...")
+    push_start = time.time()
+
+    await trainer.push_weights.call(policy_version=1)
+
+    push_time = time.time() - push_start
+    print(f"Finished weights push in ({push_time:.2f}s)\n")
+
+    # Test update_weights
+    print("Updating generator weights from store...")
+    update_start = time.time()
+
+    await policy.update_weights.call(version=1)
+
+    update_time = time.time() - update_start
+    print(f"Updated generator weights ({update_time:.2f}s)\n")
+
+    # TODO - ideally we have the capability to check forward passes between
+    # the trainer/generator to verify correctness. This would require adding
+    # forward capabilities to both trainer/generator actors.
+
+    # Summary
+    print("=" * 80)
+    print("Results")
+    print("=" * 80)
+    print(f"Push time: {push_time:.2f}s")
+    print(f"Update time: {update_time:.2f}s")
+    print(f"Total sync time: {push_time + update_time:.2f}s")
+    print("=" * 80 + "\n")
+
+    # Cleanup
+    print("Shutting down...")
+    await shutdown()
+    print("Shutdown complete.")
+
+
+if __name__ == "__main__":
+
+    @parse
+    def _main(cfg):
+        asyncio.run(main(cfg))
+
+    _main()
diff --git a/tests/sandbox/weight_sync/qwen3_1_7b.yaml b/tests/sandbox/weight_sync/qwen3_1_7b.yaml
new file mode 100644
index 000000000..ea28a1471
--- /dev/null
+++ b/tests/sandbox/weight_sync/qwen3_1_7b.yaml
@@ -0,0 +1,74 @@
+# Weight Sync Sandbox Configuration
+# >>> python -m tests.sandbox.weight_sync.main --config tests/sandbox/weight_sync/qwen3_1_7b.yaml
+
+model: "Qwen/Qwen3-1.7B"
+local_batch_size: 4
+max_req_tokens: 64
+max_res_tokens: 64
+
+metric_logging:
+  console:
+    logging_mode: global_reduce
+
+policy:
+  prefetch_weights_to_shm: false  # Disable to avoid shared memory warnings in test
+  engine_args:
+    model: ${model}
+    tensor_parallel_size: 1
+    pipeline_parallel_size: 1
+    enforce_eager: true
+  sampling_params:
+    n: 1
+    max_tokens: 32  # Just for verification forward pass
+    temperature: 1.0
+    top_p: 1.0
+
+trainer:
+  model:
+    name: qwen3
+    flavor: 1.7B
+    hf_assets_path: hf://${model}
+  optimizer:
+    name: AdamW
+    lr: 1e-5
+    eps: 1e-8
+  lr_scheduler:
+    warmup_steps: 1
+  training:
+    local_batch_size: ${local_batch_size}
+    seq_len: 128  # max_req_tokens + max_res_tokens
+    max_norm: 1.0
+    steps: 1  # We only run 1 step
+    dtype: bfloat16
+    gc_freq: 1
+  compile:
+    enable: false
+  parallelism:
+    data_parallel_replicate_degree: 1
+    data_parallel_shard_degree: 1  # Single GPU, no FSDP
+    tensor_parallel_degree: 1
+    pipeline_parallel_degree: 1
+    context_parallel_degree: 1
+    expert_parallel_degree: 1
+    disable_loss_parallel: true
+  checkpoint:
+    enable: true
+    folder: ./checkpoint
+    initial_load_path: hf://${model}
+    initial_load_in_hf: true
+    last_save_in_hf: true
+    async_mode: "disabled"
+  activation_checkpoint:
+    mode: selective
+    selective_ac_option: op
+
+# Resource allocation - both as actors
+actors:
+  policy:
+    procs: 1  # Single process for generator
+    with_gpus: true
+    mesh_name: policy
+  trainer:
+    procs: 1  # Single process for trainer
+    with_gpus: true
+    mesh_name: trainer
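
Note on the TODO near the end of main(): once the trainer and generator actors gain forward endpoints, sync correctness could be verified by running both on the same token batch and comparing logits. The sketch below is only illustrative; `trainer.forward.call(...)` and `policy.forward.call(...)` are hypothetical endpoints that do not exist yet (the TODO itself says they would have to be added), and only the tensor comparison uses existing torch APIs. A tolerance-based comparison is used rather than exact equality because the bf16 trainer and the vLLM engine will generally not produce bit-identical logits even when their weights match.

import torch


def logits_match(
    trainer_logits: torch.Tensor,
    generator_logits: torch.Tensor,
    atol: float = 1e-2,
    rtol: float = 1e-3,
) -> bool:
    """Return True if the two forward passes agree within tolerance."""
    # Report the worst-case disagreement to make tolerance tuning easier.
    max_diff = (trainer_logits - generator_logits).abs().max().item()
    print(f"Max absolute logit difference: {max_diff:.6f}")
    return torch.allclose(trainer_logits, generator_logits, atol=atol, rtol=rtol)


# Hypothetical usage once forward endpoints exist (names are placeholders):
#   tokens = inputs[0]["tokens"]
#   trainer_logits = await trainer.forward.call(tokens)
#   generator_logits = await policy.forward.call(tokens)
#   assert logits_match(trainer_logits.float().cpu(), generator_logits.float().cpu())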