From 404484ec27d83247319a1735d30147830bd97681 Mon Sep 17 00:00:00 2001 From: Allen Wang Date: Wed, 5 Nov 2025 14:39:12 -0800 Subject: [PATCH 1/6] weight sync sandbox --- .gitignore | 3 + tests/sandbox/weight_sync/README.md | 105 ++++++++ tests/sandbox/weight_sync/main.py | 286 ++++++++++++++++++++++ tests/sandbox/weight_sync/qwen3_1_7b.yaml | 82 +++++++ 4 files changed, 476 insertions(+) create mode 100644 tests/sandbox/weight_sync/README.md create mode 100644 tests/sandbox/weight_sync/main.py create mode 100644 tests/sandbox/weight_sync/qwen3_1_7b.yaml diff --git a/.gitignore b/.gitignore index c952405d6..d5a139050 100644 --- a/.gitignore +++ b/.gitignore @@ -204,3 +204,6 @@ demo_top_down.md # enroot / sqsh *.sqsh + +# claude magic +.claude diff --git a/tests/sandbox/weight_sync/README.md b/tests/sandbox/weight_sync/README.md new file mode 100644 index 000000000..5f3bfa923 --- /dev/null +++ b/tests/sandbox/weight_sync/README.md @@ -0,0 +1,105 @@ +# Weight Sync Sandbox + +A minimal test environment focused exclusively on testing the weight synchronization mechanism between `RLTrainer` and `Generator`. + +## Purpose + +This sandbox tests the complete weight sync pipeline in isolation, without the complexity of a full RL training loop. It's designed for: + +- **Debugging weight sync issues**: Isolate and test push/pull mechanisms +- **Performance profiling**: Measure push_weights() and update_weights() latency +- **Mode verification**: Test both DCP (filesystem) and Direct (RDMA) sync modes +- **Quick validation**: Verify weight sync works before running full training + +## What It Tests + +### Push Weights (`RLTrainer.push_weights()`) +- Converts TorchTitan state dict to HuggingFace format +- Saves to torchstore via DCP (no RDMA) or direct (with RDMA) +- Measures push performance + +### Update Weights (`Generator.update_weights()`) +- Stops accepting new requests (lock mechanism) +- Waits for pending requests to complete +- Fetches weights from torchstore +- Loads weights into vLLM model +- Resumes accepting requests +- Measures update performance + +### Verification +- Runs forward pass with updated weights to confirm success + +## Usage + +```bash +python -m tests.sandbox.weight_sync.main --config tests/sandbox/weight_sync/qwen3_1_7b.yaml +``` + +## Test Flow + +1. **Initialize** trainer and generator with same model (Qwen3-1.7B) +2. **Run ONE training step** to create a weight delta +3. **Push weights** from trainer to torchstore (version 1) +4. **Update weights** in generator from torchstore (version 1) +5. **Verify** with forward pass to confirm weights loaded correctly + +## Output + +The sandbox prints: +- Sync mode used (DCP vs Direct/RDMA) +- Time for push_weights() +- Time for update_weights() +- Total sync time +- Sample generation output for verification + +Example output: +``` +================================================================================ +WEIGHT SYNC SANDBOX +================================================================================ +Model: Qwen/Qwen3-1.7B +RDMA available: False +Sync mode: DCP (Filesystem) +================================================================================ + +✓ Initialization complete (12.34s) + +[1/4] Running single training step to create weight delta... +✓ Training step complete (0.56s) + +[2/4] Testing push_weights() to torchstore... +✓ Pushed weights to torchstore (2.13s) + +[3/4] Testing update_weights() from torchstore... 
+✓ Updated weights in generator (1.87s) + +[4/4] Verification: Running forward pass with updated weights... +✓ Forward pass successful (0.23s) + +================================================================================ +WEIGHT SYNC TEST COMPLETE +================================================================================ +Push time: 2.13s +Update time: 1.87s +Total sync time: 4.00s +Sync mode used: DCP (Filesystem) +================================================================================ +``` + +## Configuration + +Uses Qwen3-1.7B for fast testing with minimal resource requirements: +- **Model**: Qwen/Qwen3-1.7B (small, fast to load) +- **Batch size**: 4 (minimal overhead) +- **Sequence length**: 128 tokens (64 request + 64 response) +- **Generator**: Single process actor (not service) +- **Trainer**: Single process, no FSDP +- **Data parallel**: 1 (single GPU) + +## Key Differences from Other Sandboxes + +- **vs vllm sandbox**: Adds trainer and tests weight updates +- **vs rl_trainer sandbox**: Adds generator and tests weight loading +- **vs toy_rl sandbox**: Focuses purely on weight sync, not full RL loop + +This is the **only sandbox that tests the complete weight sync mechanism** in isolation. diff --git a/tests/sandbox/weight_sync/main.py b/tests/sandbox/weight_sync/main.py new file mode 100644 index 000000000..de02d6b9e --- /dev/null +++ b/tests/sandbox/weight_sync/main.py @@ -0,0 +1,286 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Weight Sync Sandbox + +A minimal test environment focused exclusively on testing the weight synchronization +mechanism between RLTrainer and Generator. + +This sandbox: +- Initializes both trainer and generator with the same model +- Runs ONE training step to create a weight delta +- Tests push_weights() performance and correctness +- Tests update_weights() performance and correctness +- Verifies the sync with a forward pass + +Usage: + python -m tests.sandbox.weight_sync.main --config tests/sandbox/weight_sync/qwen3_1_7b.yaml +""" + +import asyncio +import os +import time + +import torch +import torchstore as ts +from forge.actors._torchstore_utils import rdma_enabled +from forge.actors.generator import Generator +from forge.actors.trainer import RLTrainer +from forge.controller.provisioner import init_provisioner, shutdown +from forge.observability.metric_actors import get_or_create_metric_logger +from forge.types import LauncherConfig, ProvisionerConfig +from forge.util.config import parse +from omegaconf import DictConfig +from vllm.transformers_utils.tokenizer import get_tokenizer + +# Suppress resource_tracker warnings about shared memory cleanup +# These occur because shared memory is cleaned up by one process while the +# resource tracker in another process tries to clean it up again. The cleanup +# is working correctly - these are just harmless race condition warnings. +os.environ["PYTHONWARNINGS"] = "ignore::UserWarning:multiprocessing.resource_tracker" + + +def simple_grpo_loss( + logits: torch.Tensor, + response: torch.Tensor, + ref_logprobs: torch.Tensor, + advantages: torch.Tensor, + padding_mask: torch.Tensor, + beta: float = 0.1, +) -> torch.Tensor: + """ + Simplified loss function for weight sync testing. + Just performs basic tensor operations to create a weight delta. 
+ """ + # Extract dimensions + local_batch_size, response_len = response.shape + vocab_size = logits.size(-1) + full_seq_len = logits.size(1) + + # Extract only the response portion from logits + request_len = full_seq_len - response_len + response_logits = logits[:, request_len:, :] + + # Flatten logits and response for cross-entropy + logits_flat = response_logits.reshape(-1, vocab_size) + response_flat = response.reshape(-1) + + # Basic cross-entropy loss + loss = torch.nn.functional.cross_entropy( + logits_flat, response_flat, reduction="none" + ).view(local_batch_size, response_len) + + # Apply padding mask and reduce + masked_loss = loss * padding_mask + loss = masked_loss.sum() / padding_mask.sum().clamp(min=1.0) + + return loss + + +def generate_random_batch( + local_batch_size: int, + request_len: int, + response_len: int, + vocab_size: int = 32000, + device: str = "cuda", + dp_size: int = 1, +): + """ + Generate random input and target tensors for a single training step. + Creates one batch per data parallel rank. + """ + inputs = [] + targets = [] + + # Create one batch for each data parallel rank + for _ in range(dp_size): + request = torch.randint( + 1, + vocab_size, + (local_batch_size, request_len), + dtype=torch.long, + device=device, + ) + response = torch.randint( + 1, + vocab_size, + (local_batch_size, response_len), + dtype=torch.long, + device=device, + ) + + # Create padding mask + padding_mask = torch.rand((local_batch_size, response_len), device=device) > 0.1 + + ref_logprobs = ( + -torch.abs(torch.randn((local_batch_size, response_len), device=device)) + - 1.0 + ) + advantages = torch.randn((local_batch_size, 1), device=device) + input_tokens = torch.cat([request, response], dim=1) + inputs.append({"tokens": input_tokens}) + targets.append( + { + "response": response, + "ref_logprobs": ref_logprobs, + "advantages": advantages, + "padding_mask": padding_mask, + } + ) + + return inputs, targets + + +async def main(cfg: DictConfig): + """ + Weight sync sandbox main function. + + Tests the complete weight synchronization pipeline: + 1. Initialize trainer and generator + 2. Run one training step + 3. Push weights from trainer + 4. Update weights in generator + 5. 
Verify with forward pass + """ + + # Extract configuration + local_batch_size = cfg.get("local_batch_size", None) + assert local_batch_size is not None, "local_batch_size must be specified" + + request_len = cfg.get("max_req_tokens", 64) + response_len = cfg.get("max_res_tokens", 64) + model_name = cfg.get("model") + + # Get vocab size from tokenizer + print(f"Loading tokenizer for model: {model_name}") + tokenizer = get_tokenizer(model_name) + vocab_size = tokenizer.vocab_size + print(f"Detected vocab size: {vocab_size}") + + # Get data parallel size + dp_size = cfg.get("replay_buffer", {}).get("dp_size", 1) + if dp_size is None: + trainer_dp_degree = cfg.trainer.parallelism.get("data_parallel_shard_degree", 1) + dp_size = trainer_dp_degree if trainer_dp_degree != -1 else 1 + + # ---- Global setups ---- # + provisioner = None + if cfg.get("provisioner", None) is not None: + provisioner = await init_provisioner( + ProvisionerConfig(launcher_config=LauncherConfig(**cfg.provisioner)) + ) + else: + provisioner = await init_provisioner() + + metric_logging_cfg = cfg.get("metric_logging", {}) + mlogger = await get_or_create_metric_logger(process_name="Controller") + await mlogger.init_backends.call_one(metric_logging_cfg) + + # Initialize torchstore + await ts.initialize(strategy=ts.ControllerStorageVolumes()) + + print("\n" + "=" * 80) + print("WEIGHT SYNC SANDBOX") + print("=" * 80) + print(f"Model: {model_name}") + print(f"Local batch size: {local_batch_size}") + print( + f"Sequence length: {request_len + response_len} ({request_len} + {response_len})" + ) + print(f"Data parallel size: {dp_size}") + print(f"RDMA available: {rdma_enabled()}") + print(f"Sync mode: {'Direct (RDMA)' if rdma_enabled() else 'DCP (Filesystem)'}") + print("=" * 80 + "\n") + + # Initialize trainer and generator + print("Initializing trainer and generator...") + init_start = time.time() + + trainer, policy = await asyncio.gather( + RLTrainer.options(**cfg.actors.trainer).as_actor( + **cfg.trainer, loss=simple_grpo_loss + ), + Generator.options(**cfg.actors.policy).as_actor(**cfg.policy), + ) + + init_time = time.time() - init_start + print(f"✓ Initialization complete ({init_time:.2f}s)\n") + + # Run one training step to create weight delta + print("[1/4] Running single training step to create weight delta...") + step_start = time.time() + + inputs, targets = generate_random_batch( + local_batch_size=local_batch_size, + request_len=request_len, + response_len=response_len, + vocab_size=vocab_size, + dp_size=dp_size, + ) + + await trainer.train_step.call(inputs, targets) + step_time = time.time() - step_start + print(f"✓ Training step complete ({step_time:.2f}s)\n") + + # Test push_weights + print("[2/4] Testing push_weights() to torchstore...") + push_start = time.time() + + await trainer.push_weights.call(policy_version=1) + + push_time = time.time() - push_start + print(f"✓ Pushed weights to torchstore ({push_time:.2f}s)\n") + + # Test update_weights + print("[3/4] Testing update_weights() from torchstore...") + update_start = time.time() + + await policy.update_weights.call(version=1) + + update_time = time.time() - update_start + print(f"✓ Updated weights in generator ({update_time:.2f}s)\n") + + # Verify with forward pass + print("[4/4] Verification: Running forward pass with updated weights...") + verify_start = time.time() + + test_prompt = "Write a short poem" + result = await policy.generate.call(prompt=test_prompt) + # Unwrap the ValueMesh + _, result = next(result.items()) + + verify_time = time.time() - 
verify_start + print(f"✓ Forward pass successful ({verify_time:.2f}s)") + print("\nSample output:") + print(f"Prompt: {test_prompt}") + print(f"Response: {result[0].text[:100]}...\n") + + # Summary + print("=" * 80) + print("WEIGHT SYNC TEST COMPLETE") + print("=" * 80) + print(f"Push time: {push_time:.2f}s") + print(f"Update time: {update_time:.2f}s") + print(f"Total sync time: {push_time + update_time:.2f}s") + print( + f"Sync mode used: {'Direct (RDMA)' if rdma_enabled() else 'DCP (Filesystem)'}" + ) + print("=" * 80 + "\n") + + # Cleanup + print("Shutting down...") + await shutdown() + print("✓ Shutdown complete.") + + +if __name__ == "__main__": + + @parse + def _main(cfg): + asyncio.run(main(cfg)) + + _main() diff --git a/tests/sandbox/weight_sync/qwen3_1_7b.yaml b/tests/sandbox/weight_sync/qwen3_1_7b.yaml new file mode 100644 index 000000000..fd34bd45e --- /dev/null +++ b/tests/sandbox/weight_sync/qwen3_1_7b.yaml @@ -0,0 +1,82 @@ +# Weight Sync Sandbox Configuration +# >>> python -m tests.sandbox.weight_sync.main --config tests/sandbox/weight_sync/qwen3_1_7b.yaml + +# Global configuration +model: "Qwen/Qwen3-1.7B" +local_batch_size: 4 # Small batch - we only need ONE training step +max_req_tokens: 64 # Short sequences for fast testing +max_res_tokens: 64 + +# Observability configuration +metric_logging: + console: + logging_mode: global_reduce + +# Generator configuration (as actor, not service) +policy: + engine_args: + model: ${model} + tensor_parallel_size: 1 + pipeline_parallel_size: 1 + enforce_eager: true # Faster initialization + sampling_params: + n: 1 + max_tokens: 32 # Just for verification forward pass + temperature: 1.0 + top_p: 1.0 + +# Trainer configuration (minimal setup) +trainer: + model: + name: qwen3 + flavor: 1.7B + hf_assets_path: hf://${model} + optimizer: + name: AdamW + lr: 1e-5 + eps: 1e-8 + lr_scheduler: + warmup_steps: 1 + training: + local_batch_size: ${local_batch_size} + seq_len: 128 # max_req_tokens + max_res_tokens + max_norm: 1.0 + steps: 1 # We only run 1 step + dtype: bfloat16 + gc_freq: 1 + compile: + enable: false + parallelism: + data_parallel_replicate_degree: 1 + data_parallel_shard_degree: 1 # Single GPU, no FSDP + tensor_parallel_degree: 1 + pipeline_parallel_degree: 1 + context_parallel_degree: 1 + expert_parallel_degree: 1 + disable_loss_parallel: true + checkpoint: + enable: true + folder: ./checkpoint + initial_load_path: hf://${model} + initial_load_in_hf: true + last_save_in_hf: true + async_mode: "disabled" + activation_checkpoint: + mode: selective + selective_ac_option: op + +# Replay buffer configuration (needed by trainer) +replay_buffer: + batch_size: ${local_batch_size} + dp_size: 1 # Must match trainer's data_parallel_shard_degree + +# Resource allocation - both as actors +actors: + policy: + procs: 1 # Single process for generator + with_gpus: true + mesh_name: policy + trainer: + procs: 1 # Single process for trainer + with_gpus: true + mesh_name: trainer From 6ee6cd6017db06d204b4a2df255c016121b098fb Mon Sep 17 00:00:00 2001 From: Allen Wang Date: Wed, 5 Nov 2025 14:53:23 -0800 Subject: [PATCH 2/6] some cleanups --- tests/sandbox/weight_sync/README.md | 105 ---------------------- tests/sandbox/weight_sync/main.py | 104 ++++----------------- tests/sandbox/weight_sync/qwen3_1_7b.yaml | 16 +--- 3 files changed, 21 insertions(+), 204 deletions(-) delete mode 100644 tests/sandbox/weight_sync/README.md diff --git a/tests/sandbox/weight_sync/README.md b/tests/sandbox/weight_sync/README.md deleted file mode 100644 index 
5f3bfa923..000000000 --- a/tests/sandbox/weight_sync/README.md +++ /dev/null @@ -1,105 +0,0 @@ -# Weight Sync Sandbox - -A minimal test environment focused exclusively on testing the weight synchronization mechanism between `RLTrainer` and `Generator`. - -## Purpose - -This sandbox tests the complete weight sync pipeline in isolation, without the complexity of a full RL training loop. It's designed for: - -- **Debugging weight sync issues**: Isolate and test push/pull mechanisms -- **Performance profiling**: Measure push_weights() and update_weights() latency -- **Mode verification**: Test both DCP (filesystem) and Direct (RDMA) sync modes -- **Quick validation**: Verify weight sync works before running full training - -## What It Tests - -### Push Weights (`RLTrainer.push_weights()`) -- Converts TorchTitan state dict to HuggingFace format -- Saves to torchstore via DCP (no RDMA) or direct (with RDMA) -- Measures push performance - -### Update Weights (`Generator.update_weights()`) -- Stops accepting new requests (lock mechanism) -- Waits for pending requests to complete -- Fetches weights from torchstore -- Loads weights into vLLM model -- Resumes accepting requests -- Measures update performance - -### Verification -- Runs forward pass with updated weights to confirm success - -## Usage - -```bash -python -m tests.sandbox.weight_sync.main --config tests/sandbox/weight_sync/qwen3_1_7b.yaml -``` - -## Test Flow - -1. **Initialize** trainer and generator with same model (Qwen3-1.7B) -2. **Run ONE training step** to create a weight delta -3. **Push weights** from trainer to torchstore (version 1) -4. **Update weights** in generator from torchstore (version 1) -5. **Verify** with forward pass to confirm weights loaded correctly - -## Output - -The sandbox prints: -- Sync mode used (DCP vs Direct/RDMA) -- Time for push_weights() -- Time for update_weights() -- Total sync time -- Sample generation output for verification - -Example output: -``` -================================================================================ -WEIGHT SYNC SANDBOX -================================================================================ -Model: Qwen/Qwen3-1.7B -RDMA available: False -Sync mode: DCP (Filesystem) -================================================================================ - -✓ Initialization complete (12.34s) - -[1/4] Running single training step to create weight delta... -✓ Training step complete (0.56s) - -[2/4] Testing push_weights() to torchstore... -✓ Pushed weights to torchstore (2.13s) - -[3/4] Testing update_weights() from torchstore... -✓ Updated weights in generator (1.87s) - -[4/4] Verification: Running forward pass with updated weights... 
-✓ Forward pass successful (0.23s) - -================================================================================ -WEIGHT SYNC TEST COMPLETE -================================================================================ -Push time: 2.13s -Update time: 1.87s -Total sync time: 4.00s -Sync mode used: DCP (Filesystem) -================================================================================ -``` - -## Configuration - -Uses Qwen3-1.7B for fast testing with minimal resource requirements: -- **Model**: Qwen/Qwen3-1.7B (small, fast to load) -- **Batch size**: 4 (minimal overhead) -- **Sequence length**: 128 tokens (64 request + 64 response) -- **Generator**: Single process actor (not service) -- **Trainer**: Single process, no FSDP -- **Data parallel**: 1 (single GPU) - -## Key Differences from Other Sandboxes - -- **vs vllm sandbox**: Adds trainer and tests weight updates -- **vs rl_trainer sandbox**: Adds generator and tests weight loading -- **vs toy_rl sandbox**: Focuses purely on weight sync, not full RL loop - -This is the **only sandbox that tests the complete weight sync mechanism** in isolation. diff --git a/tests/sandbox/weight_sync/main.py b/tests/sandbox/weight_sync/main.py index de02d6b9e..df6860697 100644 --- a/tests/sandbox/weight_sync/main.py +++ b/tests/sandbox/weight_sync/main.py @@ -10,19 +10,11 @@ A minimal test environment focused exclusively on testing the weight synchronization mechanism between RLTrainer and Generator. -This sandbox: -- Initializes both trainer and generator with the same model -- Runs ONE training step to create a weight delta -- Tests push_weights() performance and correctness -- Tests update_weights() performance and correctness -- Verifies the sync with a forward pass - Usage: python -m tests.sandbox.weight_sync.main --config tests/sandbox/weight_sync/qwen3_1_7b.yaml """ import asyncio -import os import time import torch @@ -37,49 +29,6 @@ from omegaconf import DictConfig from vllm.transformers_utils.tokenizer import get_tokenizer -# Suppress resource_tracker warnings about shared memory cleanup -# These occur because shared memory is cleaned up by one process while the -# resource tracker in another process tries to clean it up again. The cleanup -# is working correctly - these are just harmless race condition warnings. -os.environ["PYTHONWARNINGS"] = "ignore::UserWarning:multiprocessing.resource_tracker" - - -def simple_grpo_loss( - logits: torch.Tensor, - response: torch.Tensor, - ref_logprobs: torch.Tensor, - advantages: torch.Tensor, - padding_mask: torch.Tensor, - beta: float = 0.1, -) -> torch.Tensor: - """ - Simplified loss function for weight sync testing. - Just performs basic tensor operations to create a weight delta. 
- """ - # Extract dimensions - local_batch_size, response_len = response.shape - vocab_size = logits.size(-1) - full_seq_len = logits.size(1) - - # Extract only the response portion from logits - request_len = full_seq_len - response_len - response_logits = logits[:, request_len:, :] - - # Flatten logits and response for cross-entropy - logits_flat = response_logits.reshape(-1, vocab_size) - response_flat = response.reshape(-1) - - # Basic cross-entropy loss - loss = torch.nn.functional.cross_entropy( - logits_flat, response_flat, reduction="none" - ).view(local_batch_size, response_len) - - # Apply padding mask and reduce - masked_loss = loss * padding_mask - loss = masked_loss.sum() / padding_mask.sum().clamp(min=1.0) - - return loss - def generate_random_batch( local_batch_size: int, @@ -136,18 +85,6 @@ def generate_random_batch( async def main(cfg: DictConfig): - """ - Weight sync sandbox main function. - - Tests the complete weight synchronization pipeline: - 1. Initialize trainer and generator - 2. Run one training step - 3. Push weights from trainer - 4. Update weights in generator - 5. Verify with forward pass - """ - - # Extract configuration local_batch_size = cfg.get("local_batch_size", None) assert local_batch_size is not None, "local_batch_size must be specified" @@ -155,17 +92,13 @@ async def main(cfg: DictConfig): response_len = cfg.get("max_res_tokens", 64) model_name = cfg.get("model") - # Get vocab size from tokenizer print(f"Loading tokenizer for model: {model_name}") tokenizer = get_tokenizer(model_name) vocab_size = tokenizer.vocab_size print(f"Detected vocab size: {vocab_size}") - # Get data parallel size - dp_size = cfg.get("replay_buffer", {}).get("dp_size", 1) - if dp_size is None: - trainer_dp_degree = cfg.trainer.parallelism.get("data_parallel_shard_degree", 1) - dp_size = trainer_dp_degree if trainer_dp_degree != -1 else 1 + trainer_dp_degree = cfg.trainer.parallelism.get("data_parallel_shard_degree", 1) + dp_size = trainer_dp_degree if trainer_dp_degree != -1 else 1 # ---- Global setups ---- # provisioner = None @@ -183,8 +116,6 @@ async def main(cfg: DictConfig): # Initialize torchstore await ts.initialize(strategy=ts.ControllerStorageVolumes()) - print("\n" + "=" * 80) - print("WEIGHT SYNC SANDBOX") print("=" * 80) print(f"Model: {model_name}") print(f"Local batch size: {local_batch_size}") @@ -192,8 +123,7 @@ async def main(cfg: DictConfig): f"Sequence length: {request_len + response_len} ({request_len} + {response_len})" ) print(f"Data parallel size: {dp_size}") - print(f"RDMA available: {rdma_enabled()}") - print(f"Sync mode: {'Direct (RDMA)' if rdma_enabled() else 'DCP (Filesystem)'}") + print(f"Is RDMA available? 
{rdma_enabled()}") print("=" * 80 + "\n") # Initialize trainer and generator @@ -202,16 +132,19 @@ async def main(cfg: DictConfig): trainer, policy = await asyncio.gather( RLTrainer.options(**cfg.actors.trainer).as_actor( - **cfg.trainer, loss=simple_grpo_loss + **cfg.trainer, + loss=lambda *args, **kwargs: torch.tensor( + 1.0, requires_grad=True, device="cuda" + ), ), Generator.options(**cfg.actors.policy).as_actor(**cfg.policy), ) init_time = time.time() - init_start - print(f"✓ Initialization complete ({init_time:.2f}s)\n") + print(f"Finished initialization in ({init_time:.2f}s)") # Run one training step to create weight delta - print("[1/4] Running single training step to create weight delta...") + print("Running single training step.....") step_start = time.time() inputs, targets = generate_random_batch( @@ -224,28 +157,28 @@ async def main(cfg: DictConfig): await trainer.train_step.call(inputs, targets) step_time = time.time() - step_start - print(f"✓ Training step complete ({step_time:.2f}s)\n") + print(f"Finished train step in ({step_time:.2f}s)\n") # Test push_weights - print("[2/4] Testing push_weights() to torchstore...") + print("Pushing weights from trainer to store...") push_start = time.time() await trainer.push_weights.call(policy_version=1) push_time = time.time() - push_start - print(f"✓ Pushed weights to torchstore ({push_time:.2f}s)\n") + print(f"Finished weights push in ({push_time:.2f}s)\n") # Test update_weights - print("[3/4] Testing update_weights() from torchstore...") + print("Updating generator weights from store...") update_start = time.time() await policy.update_weights.call(version=1) update_time = time.time() - update_start - print(f"✓ Updated weights in generator ({update_time:.2f}s)\n") + print(f"Updated generator weights ({update_time:.2f}s)\n") # Verify with forward pass - print("[4/4] Verification: Running forward pass with updated weights...") + print("Verifying forward pass with updated weights...") verify_start = time.time() test_prompt = "Write a short poem" @@ -254,21 +187,18 @@ async def main(cfg: DictConfig): _, result = next(result.items()) verify_time = time.time() - verify_start - print(f"✓ Forward pass successful ({verify_time:.2f}s)") + print(f"Finished testing forward pass in ({verify_time:.2f}s)") print("\nSample output:") print(f"Prompt: {test_prompt}") print(f"Response: {result[0].text[:100]}...\n") # Summary print("=" * 80) - print("WEIGHT SYNC TEST COMPLETE") + print("Results") print("=" * 80) print(f"Push time: {push_time:.2f}s") print(f"Update time: {update_time:.2f}s") print(f"Total sync time: {push_time + update_time:.2f}s") - print( - f"Sync mode used: {'Direct (RDMA)' if rdma_enabled() else 'DCP (Filesystem)'}" - ) print("=" * 80 + "\n") # Cleanup diff --git a/tests/sandbox/weight_sync/qwen3_1_7b.yaml b/tests/sandbox/weight_sync/qwen3_1_7b.yaml index fd34bd45e..ea28a1471 100644 --- a/tests/sandbox/weight_sync/qwen3_1_7b.yaml +++ b/tests/sandbox/weight_sync/qwen3_1_7b.yaml @@ -1,31 +1,28 @@ # Weight Sync Sandbox Configuration # >>> python -m tests.sandbox.weight_sync.main --config tests/sandbox/weight_sync/qwen3_1_7b.yaml -# Global configuration model: "Qwen/Qwen3-1.7B" -local_batch_size: 4 # Small batch - we only need ONE training step -max_req_tokens: 64 # Short sequences for fast testing +local_batch_size: 4 +max_req_tokens: 64 max_res_tokens: 64 -# Observability configuration metric_logging: console: logging_mode: global_reduce -# Generator configuration (as actor, not service) policy: + prefetch_weights_to_shm: false # 
Disable to avoid shared memory warnings in test engine_args: model: ${model} tensor_parallel_size: 1 pipeline_parallel_size: 1 - enforce_eager: true # Faster initialization + enforce_eager: true sampling_params: n: 1 max_tokens: 32 # Just for verification forward pass temperature: 1.0 top_p: 1.0 -# Trainer configuration (minimal setup) trainer: model: name: qwen3 @@ -65,11 +62,6 @@ trainer: mode: selective selective_ac_option: op -# Replay buffer configuration (needed by trainer) -replay_buffer: - batch_size: ${local_batch_size} - dp_size: 1 # Must match trainer's data_parallel_shard_degree - # Resource allocation - both as actors actors: policy: From 7fc533a67ca41d5e40cdbfd8f13636bd7b920548 Mon Sep 17 00:00:00 2001 From: Allen Wang Date: Wed, 5 Nov 2025 14:59:01 -0800 Subject: [PATCH 3/6] need to do some more stuff --- tests/sandbox/weight_sync/main.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/tests/sandbox/weight_sync/main.py b/tests/sandbox/weight_sync/main.py index df6860697..ac6078c48 100644 --- a/tests/sandbox/weight_sync/main.py +++ b/tests/sandbox/weight_sync/main.py @@ -177,20 +177,33 @@ async def main(cfg: DictConfig): update_time = time.time() - update_start print(f"Updated generator weights ({update_time:.2f}s)\n") - # Verify with forward pass - print("Verifying forward pass with updated weights...") + # Verify with forward pass - compare trainer and generator outputs + print("Verifying weight sync by comparing trainer and generator outputs...") verify_start = time.time() + # Create a simple test input + test_input_ids = torch.randint( + 1, vocab_size, (1, 32), dtype=torch.long, device="cuda" + ) + + # Get logits from trainer + # TODO - let's add in extended trainer/generator actor implementations here + # where we can add new endpoints + # We probably also will add these endpoints anyways + + # Get logits from generator via generation test_prompt = "Write a short poem" result = await policy.generate.call(prompt=test_prompt) # Unwrap the ValueMesh _, result = next(result.items()) verify_time = time.time() - verify_start - print(f"Finished testing forward pass in ({verify_time:.2f}s)") + print(f"Finished verification in ({verify_time:.2f}s)") print("\nSample output:") print(f"Prompt: {test_prompt}") - print(f"Response: {result[0].text[:100]}...\n") + print(f"Response: {result[0].text[:100]}...") + print("Note: Full logit comparison requires exposing trainer inference endpoint") + print() # Summary print("=" * 80) From bf83bf4db473fda0042220ef68794c959af52627 Mon Sep 17 00:00:00 2001 From: Allen Wang Date: Thu, 6 Nov 2025 08:55:47 -0800 Subject: [PATCH 4/6] slight edit --- tests/sandbox/weight_sync/main.py | 32 ++----------------------------- 1 file changed, 2 insertions(+), 30 deletions(-) diff --git a/tests/sandbox/weight_sync/main.py b/tests/sandbox/weight_sync/main.py index ac6078c48..bd0052f05 100644 --- a/tests/sandbox/weight_sync/main.py +++ b/tests/sandbox/weight_sync/main.py @@ -144,7 +144,7 @@ async def main(cfg: DictConfig): print(f"Finished initialization in ({init_time:.2f}s)") # Run one training step to create weight delta - print("Running single training step.....") + print("Running single training step...") step_start = time.time() inputs, targets = generate_random_batch( @@ -177,34 +177,6 @@ async def main(cfg: DictConfig): update_time = time.time() - update_start print(f"Updated generator weights ({update_time:.2f}s)\n") - # Verify with forward pass - compare trainer and generator outputs - print("Verifying weight 
sync by comparing trainer and generator outputs...") - verify_start = time.time() - - # Create a simple test input - test_input_ids = torch.randint( - 1, vocab_size, (1, 32), dtype=torch.long, device="cuda" - ) - - # Get logits from trainer - # TODO - let's add in extended trainer/generator actor implementations here - # where we can add new endpoints - # We probably also will add these endpoints anyways - - # Get logits from generator via generation - test_prompt = "Write a short poem" - result = await policy.generate.call(prompt=test_prompt) - # Unwrap the ValueMesh - _, result = next(result.items()) - - verify_time = time.time() - verify_start - print(f"Finished verification in ({verify_time:.2f}s)") - print("\nSample output:") - print(f"Prompt: {test_prompt}") - print(f"Response: {result[0].text[:100]}...") - print("Note: Full logit comparison requires exposing trainer inference endpoint") - print() - # Summary print("=" * 80) print("Results") @@ -217,7 +189,7 @@ async def main(cfg: DictConfig): # Cleanup print("Shutting down...") await shutdown() - print("✓ Shutdown complete.") + print("Shutdown complete.") if __name__ == "__main__": From b077a6564cab48a6b59f2e5e9c444ac9fe523be8 Mon Sep 17 00:00:00 2001 From: Allen Wang Date: Thu, 6 Nov 2025 08:56:55 -0800 Subject: [PATCH 5/6] comment --- tests/sandbox/weight_sync/main.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/sandbox/weight_sync/main.py b/tests/sandbox/weight_sync/main.py index bd0052f05..dfdc58f0a 100644 --- a/tests/sandbox/weight_sync/main.py +++ b/tests/sandbox/weight_sync/main.py @@ -177,6 +177,10 @@ async def main(cfg: DictConfig): update_time = time.time() - update_start print(f"Updated generator weights ({update_time:.2f}s)\n") + # TODO - ideally we have the capability to check forward passes between + # the trainer/generator to verify correctness. This would require adding + # forward capabilities to both trainer/generator actors. + # Summary print("=" * 80) print("Results") From 252b2b958553ba069747975dbebd301c122bfc0d Mon Sep 17 00:00:00 2001 From: Allen Wang Date: Thu, 6 Nov 2025 08:58:05 -0800 Subject: [PATCH 6/6] no claude code --- .gitignore | 3 --- 1 file changed, 3 deletions(-) diff --git a/.gitignore b/.gitignore index d5a139050..c952405d6 100644 --- a/.gitignore +++ b/.gitignore @@ -204,6 +204,3 @@ demo_top_down.md # enroot / sqsh *.sqsh - -# claude magic -.claude
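The TODO introduced in PATCH 5/6 calls for comparing forward passes between the trainer and the generator to verify sync correctness. Below is a minimal sketch of what that check could look like, assuming hypothetical `forward` endpoints on both actors (neither `RLTrainer` nor `Generator` exposes one today) and eliding the ValueMesh unwrapping that `.call()` results require:

```python
# Sketch only: `trainer.forward.call` and `policy.forward.call` are assumed,
# illustrative endpoint names following the existing `.call()` convention
# (train_step, push_weights, update_weights) -- they are not real APIs yet.
import torch


async def verify_weight_sync(trainer, policy, vocab_size: int) -> bool:
    """Compare trainer and generator logits on the same random batch."""
    # Same random token batch fed to both models.
    tokens = torch.randint(
        1, vocab_size, (1, 32), dtype=torch.long, device="cuda"
    )

    trainer_logits = await trainer.forward.call(tokens)  # assumed endpoint
    generator_logits = await policy.forward.call(tokens)  # assumed endpoint

    # After a successful sync the two passes should agree up to kernel and
    # dtype noise, so compare with a loose tolerance, not exact equality.
    return torch.allclose(
        trainer_logits.float(), generator_logits.float(), atol=1e-2, rtol=1e-2
    )
```

The loose tolerance is deliberate: bfloat16 TorchTitan training kernels and vLLM inference kernels will not agree bit-for-bit even when the weights are identical, so `allclose` with modest `atol`/`rtol` is a more meaningful pass criterion than exact equality.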