From 10d336a33133e6d49974a6e07851f9f7f2f7a3d1 Mon Sep 17 00:00:00 2001
From: Yuxuan Hu
Date: Tue, 14 Oct 2025 10:26:40 -0700
Subject: [PATCH 01/30] shared tensor util

---
 src/forge/util/_shared_tensor.py            | 206 +++++++
 tests/unit_tests/util/test_shared_tensor.py | 580 ++++++++++++++++++++
 2 files changed, 786 insertions(+)
 create mode 100644 src/forge/util/_shared_tensor.py
 create mode 100644 tests/unit_tests/util/test_shared_tensor.py

diff --git a/src/forge/util/_shared_tensor.py b/src/forge/util/_shared_tensor.py
new file mode 100644
index 000000000..94d4d5350
--- /dev/null
+++ b/src/forge/util/_shared_tensor.py
@@ -0,0 +1,206 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import functools
+import uuid
+from multiprocessing import shared_memory
+from typing import Tuple, Union
+
+import numpy as np
+import torch
+
+
+class SharedTensor:
+    """Wrapper class for tensors backed by shared memory"""
+
+    def __init__(self, tensor=None, handle=None):
+        if tensor is not None:
+            self._create_from_tensor(tensor)
+        elif handle is not None:
+            self._create_from_handle(handle)
+        else:
+            raise ValueError("Must provide either tensor or handle")
+
+    @classmethod
+    def empty(
+        cls,
+        shape: Union[Tuple[int, ...], torch.Size],
+        dtype: torch.dtype = torch.float32,
+    ):
+        """
+        Create an empty tensor directly in shared memory (no copy/allocation overhead)
+
+        Args:
+            shape: Shape of the tensor
+            dtype: PyTorch dtype (supports bfloat16, float32, etc.)
+
+        Returns:
+            SharedTensor instance with uninitialized data
+        """
+        instance = cls.__new__(cls)
+        instance._create_empty(shape, dtype)
+        return instance
+
+    @classmethod
+    def zeros(
+        cls,
+        shape: Union[Tuple[int, ...], torch.Size],
+        dtype: torch.dtype = torch.float32,
+    ):
+        """
+        Create a zero-initialized tensor in shared memory
+
+        Args:
+            shape: Shape of the tensor
+            dtype: PyTorch dtype
+
+        Returns:
+            SharedTensor instance with zeros
+        """
+        shared_tensor = cls.empty(shape, dtype)
+        shared_tensor.tensor.zero_()
+        return shared_tensor
+
+    @classmethod
+    def ones(
+        cls,
+        shape: Union[Tuple[int, ...], torch.Size],
+        dtype: torch.dtype = torch.float32,
+    ):
+        """
+        Create a ones-initialized tensor in shared memory
+
+        Args:
+            shape: Shape of the tensor
+            dtype: PyTorch dtype
+
+        Returns:
+            SharedTensor instance with ones
+        """
+        shared_tensor = cls.empty(shape, dtype)
+        shared_tensor.tensor.fill_(1)
+        return shared_tensor
+
+    def _create_empty(self, shape, dtype):
+        """Initialize with empty tensor in shared memory"""
+        # Store metadata
+        self.shape = tuple(shape) if not isinstance(shape, tuple) else shape
+        self.dtype = dtype
+        self.dtype_str = str(dtype)
+
+        # Calculate size
+        element_size = torch.tensor([], dtype=dtype).element_size()
+        total_elements = int(np.prod(self.shape))
+        byte_size = total_elements * element_size
+
+        # Create shared memory (uninitialized - fast!)
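+        # Any process that learns this name can attach to the same segment via
+        # shared_memory.SharedMemory(name=...) with no data copy -- that is
+        # exactly what _create_from_handle below relies on.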
+ shm_name = f"shared_tensor_{uuid.uuid4().hex}" + self.shm = shared_memory.SharedMemory( + create=True, size=byte_size, name=shm_name + ) + self.shm_name = shm_name + + def _create_from_tensor(self, tensor): + """Initialize from an existing tensor""" + tensor = tensor.contiguous() + + # Store metadata + self.shape = tuple(tensor.shape) + self.dtype = tensor.dtype + self.dtype_str = str(tensor.dtype) + + # Create shared memory + byte_size = tensor.numel() * tensor.element_size() + shm_name = f"shared_tensor_{uuid.uuid4().hex}" + + self.shm = shared_memory.SharedMemory( + create=True, size=byte_size, name=shm_name + ) + self.shm_name = shm_name + + # Copy data as raw bytes + raw_bytes = tensor.view(torch.uint8).view(-1).cpu().contiguous().numpy() + self.shm.buf[:byte_size] = raw_bytes + + def _create_from_handle(self, handle): + """Initialize from a handle""" + self.shm_name = handle["shm_name"] + self.shape = handle["shape"] + self.dtype_str = handle["dtype"] + self.dtype = self._parse_dtype(self.dtype_str) + + # Attach to existing shared memory + self.shm = shared_memory.SharedMemory(name=self.shm_name) + + def _create_tensor_view(self): + """Create tensor view of shared memory.""" + element_size = torch.tensor([], dtype=self.dtype).element_size() + total_elements = int(np.prod(self.shape)) + byte_size = total_elements * element_size + + # Create numpy array that shares the buffer + np_array = np.ndarray(shape=(byte_size,), dtype=np.uint8, buffer=self.shm.buf) + # Create torch tensor from numpy (shares memory) + uint8_tensor = torch.from_numpy(np_array) + tensor = uint8_tensor.view(self.dtype).reshape(self.shape) + + # Keep both the np array and the SharedTensor object alive + tensor._forge_shared_tensor = self + tensor._forge_np_array = np_array + + return tensor + + def _parse_dtype(self, dtype_str): + """Parse dtype string""" + dtype_str = dtype_str.replace("torch.", "") + return getattr(torch, dtype_str) + + def get_handle(self): + """Get picklable handle""" + return {"shm_name": self.shm_name, "shape": self.shape, "dtype": self.dtype_str} + + @functools.cached_property + def tensor(self): + """Get the underlying tensor""" + return self._create_tensor_view() + + def copy_from(self, source_tensor): + """ + Copy data from another tensor into this shared tensor + Useful when you create empty tensor first, then fill it + + Args: + source_tensor: Source tensor to copy from + """ + if source_tensor.shape != self.shape: + raise ValueError(f"Shape mismatch: {source_tensor.shape} vs {self.shape}") + # Copy data + self.tensor.copy_(source_tensor) + + def clone(self): + """Create a new SharedTensor with copied data""" + new_shared = SharedTensor.empty(self.shape, self.dtype) + new_shared.tensor.copy_(self.tensor) + return new_shared + + def cleanup(self): + """Clean up shared memory""" + try: + self.shm.close() + self.shm.unlink() + except Exception: + pass + + def __del__(self): + """Cleanup on deletion""" + if hasattr(self, "shm"): + try: + self.shm.close() + except Exception: + pass + + def __repr__(self): + return f"SharedTensor(shape={self.shape}, dtype={self.dtype}, shm_name={self.shm_name})" diff --git a/tests/unit_tests/util/test_shared_tensor.py b/tests/unit_tests/util/test_shared_tensor.py new file mode 100644 index 000000000..ebc61c9bd --- /dev/null +++ b/tests/unit_tests/util/test_shared_tensor.py @@ -0,0 +1,580 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pickle +import time +from multiprocessing import Process, Queue + +import pytest +import torch + +# Assuming SharedTensor is in shared_tensor.py +from forge.util._shared_tensor import SharedTensor + + +class TestSharedTensorCreation: + """Test tensor creation methods""" + + def test_empty_creation(self): + """Test creating empty tensor""" + shape = (100, 200) + dtype = torch.float32 + + shared = SharedTensor.empty(shape, dtype) + + assert shared.shape == shape + assert shared.dtype == dtype + assert shared.tensor.shape == torch.Size(shape) + assert shared.tensor.dtype == dtype + + shared.cleanup() + + def test_empty_with_bfloat16(self): + """Test creating empty bfloat16 tensor""" + shape = (50, 50) + shared = SharedTensor.empty(shape, torch.bfloat16) + + assert shared.dtype == torch.bfloat16 + assert shared.tensor.dtype == torch.bfloat16 + + shared.cleanup() + + def test_zeros_creation(self): + """Test creating zero-initialized tensor""" + shape = (10, 20) + shared = SharedTensor.zeros(shape, torch.float32) + + tensor = shared.tensor + assert torch.all(tensor == 0) + assert tensor.sum().item() == 0.0 + + shared.cleanup() + + def test_ones_creation(self): + """Test creating ones-initialized tensor""" + shape = (10, 20) + shared = SharedTensor.ones(shape, torch.float32) + + tensor = shared.tensor + assert torch.all(tensor == 1) + assert tensor.sum().item() == 200.0 + + shared.cleanup() + + def test_from_tensor_creation(self): + """Test creating from existing tensor""" + original = torch.randn(50, 50) + shared = SharedTensor(tensor=original) + + assert shared.shape == tuple(original.shape) + assert shared.dtype == original.dtype + assert torch.allclose(shared.tensor, original) + + shared.cleanup() + + def test_from_handle_creation(self): + """Test creating from handle""" + # Create original + original = SharedTensor.empty((10, 10), torch.float32) + original.tensor.fill_(5.0) + + # Get handle + handle = original.get_handle() + + # Create from handle + reconstructed = SharedTensor(handle=handle) + + assert torch.all(reconstructed.tensor == 5.0) + assert reconstructed.shape == original.shape + assert reconstructed.dtype == original.dtype + + original.cleanup() + + def test_creation_requires_argument(self): + """Test that creation without arguments raises error""" + with pytest.raises(ValueError, match="Must provide either tensor or handle"): + SharedTensor() + + @pytest.mark.parametrize( + "shape", + [ + (10,), + (10, 20), + (5, 10, 15), + (2, 3, 4, 5), + ], + ) + def test_various_shapes(self, shape): + """Test creation with various shapes""" + shared = SharedTensor.empty(shape, torch.float32) + assert shared.shape == shape + assert shared.tensor.shape == torch.Size(shape) + shared.cleanup() + + +class TestSharedTensorDtypes: + """Test all supported dtypes""" + + @pytest.mark.parametrize( + "dtype", + [ + torch.float32, + torch.float64, + torch.float16, + torch.bfloat16, + torch.int32, + torch.int64, + torch.int16, + torch.int8, + torch.uint8, + torch.bool, + ], + ) + def test_all_dtypes(self, dtype): + """Test that all dtypes work correctly""" + shape = (10, 10) + shared = SharedTensor.empty(shape, dtype) + + assert shared.dtype == dtype + assert shared.tensor.dtype == dtype + + # Test that we can write to it + if dtype == torch.bool: + shared.tensor.fill_(True) + elif dtype in [torch.int32, torch.int64, torch.int16, torch.int8, torch.uint8]: + 
shared.tensor.fill_(42) + else: + shared.tensor.fill_(3.14) + + shared.cleanup() + + def test_dtype_conversion_in_handle(self): + """Test dtype is preserved through handle""" + for dtype in [torch.float32, torch.bfloat16, torch.int64]: + shared1 = SharedTensor.empty((5, 5), dtype) + handle = shared1.get_handle() + + assert handle["dtype"] == str(dtype) + + shared2 = SharedTensor(handle=handle) + assert shared2.dtype == dtype + + shared1.cleanup() + + +class TestSharedTensorOperations: + """Test tensor operations""" + + def test_copy_from(self): + """Test copying data from another tensor""" + source = torch.randn(20, 30) + shared = SharedTensor.empty((20, 30), torch.float32) + + shared.copy_from(source) + + assert torch.allclose(shared.tensor, source) + shared.cleanup() + + def test_copy_from_shape_mismatch(self): + """Test copy_from raises error on shape mismatch""" + source = torch.randn(10, 10) + shared = SharedTensor.empty((20, 20), torch.float32) + + with pytest.raises(ValueError, match="Shape mismatch"): + shared.copy_from(source) + + shared.cleanup() + + def test_clone(self): + """Test cloning creates independent copy""" + original = SharedTensor.empty((10, 10), torch.float32) + original.tensor.fill_(5.0) + + cloned = original.clone() + + # Verify data is same + assert torch.all(cloned.tensor == 5.0) + + # Verify they're independent + original.tensor.fill_(10.0) + assert torch.all(cloned.tensor == 5.0) + assert torch.all(original.tensor == 10.0) + + original.cleanup() + cloned.cleanup() + + def test_tensor_modifications(self): + """Test that modifications to tensor are reflected""" + shared = SharedTensor.zeros((10, 10), torch.float32) + tensor = shared.tensor + + tensor[0, 0] = 99.0 + tensor[5:, :] = 42.0 + + # Get tensor again and verify changes persist + tensor2 = shared.tensor + assert tensor2[0, 0].item() == 99.0 + assert torch.all(tensor2[5:, :] == 42.0) + + shared.cleanup() + + def test_inplace_operations(self): + """Test in-place operations work""" + shared = SharedTensor.empty((100, 100), torch.float32) + tensor = shared.tensor + + tensor.normal_(0, 1) + mean = tensor.mean().item() + + tensor.add_(5.0) + new_mean = tensor.mean().item() + + assert abs(new_mean - (mean + 5.0)) < 0.1 + + shared.cleanup() + + +class TestSharedTensorSerialization: + """Test pickling and handle serialization""" + + def test_handle_is_picklable(self): + """Test that handle can be pickled""" + shared = SharedTensor.empty((10, 10), torch.float32) + handle = shared.get_handle() + + # Pickle and unpickle + pickled = pickle.dumps(handle) + unpickled_handle = pickle.loads(pickled) + + assert unpickled_handle == handle + assert unpickled_handle["shm_name"] == handle["shm_name"] + assert unpickled_handle["shape"] == handle["shape"] + assert unpickled_handle["dtype"] == handle["dtype"] + + shared.cleanup() + + def test_handle_small_size(self): + """Test that handle is small (efficient for RPC)""" + shared = SharedTensor.empty((10000, 10000), torch.float32) + handle = shared.get_handle() + + pickled = pickle.dumps(handle) + + # Handle should be < 1KB even for huge tensors + assert len(pickled) < 1024 + + shared.cleanup() + + def test_data_integrity_after_pickle(self): + """Test data is preserved through handle pickling""" + # Create and fill tensor + shared1 = SharedTensor.empty((50, 50), torch.bfloat16) + shared1.tensor.normal_(0, 1) + original_data = shared1.tensor.clone() + + # Pickle handle + handle = shared1.get_handle() + pickled = pickle.dumps(handle) + unpickled_handle = pickle.loads(pickled) + + # 
Reconstruct + shared2 = SharedTensor(handle=unpickled_handle) + + # Verify data is same + assert torch.allclose(shared2.tensor.float(), original_data.float(), rtol=1e-3) + + shared1.cleanup() + + +class TestSharedTensorMemory: + """Test memory management and cleanup""" + + def test_cleanup(self): + """Test cleanup removes shared memory""" + shared = SharedTensor.empty((10, 10), torch.float32) + shm_name = shared.shm_name + + # Verify shared memory exists + tensor = shared.tensor + tensor.fill_(5.0) + + # Cleanup + shared.cleanup() + + # Trying to attach should fail + from multiprocessing import shared_memory + + with pytest.raises(FileNotFoundError): + shared_memory.SharedMemory(name=shm_name) + + def test_multiple_views_same_memory(self): + """Test multiple tensor views point to same memory""" + shared = SharedTensor.empty((10, 10), torch.float32) + + tensor1 = shared.tensor + tensor1.fill_(5.0) + + tensor2 = shared.tensor + assert torch.all(tensor2 == 5.0) + + # Modify through tensor2 + tensor2.fill_(10.0) + + # Verify tensor1 sees the change + assert torch.all(tensor1 == 10.0) + + shared.cleanup() + + def test_handle_reconstruction_shares_memory(self): + """Test that handle reconstruction shares same memory""" + shared1 = SharedTensor.empty((20, 20), torch.float32) + shared1.tensor.fill_(7.0) + + handle = shared1.get_handle() + shared2 = SharedTensor(handle=handle) + + # Modify through shared2 + shared2.tensor.fill_(14.0) + + # Verify shared1 sees the change + assert torch.all(shared1.tensor == 14.0) + + shared1.cleanup() + + +class TestSharedTensorEdgeCases: + """Test edge cases and error conditions""" + + def test_empty_shape(self): + """Test scalar tensor (empty shape)""" + shared = SharedTensor.empty((), torch.float32) + assert shared.shape == () + assert shared.tensor.numel() == 1 + shared.cleanup() + + def test_single_element_tensor(self): + """Test 1-element tensor""" + shared = SharedTensor.empty((1,), torch.float32) + shared.tensor.fill_(42.0) + assert shared.tensor.item() == 42.0 + shared.cleanup() + + def test_large_tensor(self): + """Test large tensor (1GB)""" + # 1GB tensor: 250M float32 elements + shape = (250_000_000,) + shared = SharedTensor.empty(shape, torch.float32) + + assert shared.shape == shape + assert shared.tensor.numel() == 250_000_000 + + shared.cleanup() + + def test_non_contiguous_tensor_conversion(self): + """Test that non-contiguous tensors are handled""" + # Create non-contiguous tensor + original = torch.randn(10, 10).t() # Transpose makes it non-contiguous + assert not original.is_contiguous() + + # Should work (internally makes contiguous) + shared = SharedTensor(tensor=original) + + # Result should match + assert torch.allclose(shared.tensor, original) + + shared.cleanup() + + def test_repr(self): + """Test string representation""" + shared = SharedTensor.empty((10, 20), torch.float32) + repr_str = repr(shared) + + assert "SharedTensor" in repr_str + assert "10, 20" in repr_str + assert "float32" in repr_str + assert shared.shm_name in repr_str + + shared.cleanup() + + +class TestSharedTensorMultiprocess: + """Test multiprocess scenarios""" + + def test_multiprocess_read(self): + """Test reading shared tensor from another process""" + + def reader_process(handle_dict, result_queue): + shared = SharedTensor(handle=handle_dict) + tensor = shared.tensor + result_queue.put(tensor.sum().item()) + + # Create shared tensor in main process + shared = SharedTensor.empty((100, 100), torch.float32) + shared.tensor.fill_(5.0) + + # Read from child process + 
result_queue = Queue() + handle = shared.get_handle() + + p = Process(target=reader_process, args=(handle, result_queue)) + p.start() + p.join() + + result = result_queue.get() + expected = 5.0 * 100 * 100 + + assert abs(result - expected) < 1e-5 + + shared.cleanup() + + def test_multiprocess_write(self): + """Test writing to shared tensor from another process""" + + def writer_process(handle_dict, value): + shared = SharedTensor(handle=handle_dict) + shared.tensor.fill_(value) + + # Create empty shared tensor + shared = SharedTensor.empty((50, 50), torch.float32) + shared.tensor.zero_() + + # Write from child process + handle = shared.get_handle() + + p = Process(target=writer_process, args=(handle, 42.0)) + p.start() + p.join() + + # Verify in main process + assert torch.all(shared.tensor == 42.0) + + shared.cleanup() + + def test_multiprocess_bidirectional(self): + """Test bidirectional communication""" + + def worker_process(input_handle, output_handle): + input_tensor = SharedTensor(handle=input_handle).tensor + output_tensor = SharedTensor(handle=output_handle).tensor + + # Compute: output = input * 2 + output_tensor.copy_(input_tensor * 2) + + # Create input and output tensors + input_shared = SharedTensor.empty((100, 100), torch.float32) + input_shared.tensor.normal_(0, 1) + input_data = input_shared.tensor.clone() + + output_shared = SharedTensor.empty((100, 100), torch.float32) + + # Process in child + p = Process( + target=worker_process, + args=(input_shared.get_handle(), output_shared.get_handle()), + ) + p.start() + p.join() + + # Verify result + expected = input_data * 2 + assert torch.allclose( + output_shared.tensor, expected + ), "output: {}, expected: {}".format(output_shared.tensor, expected) + + input_shared.cleanup() + output_shared.cleanup() + + +class TestSharedTensorPerformance: + """Performance-related tests""" + + def test_empty_faster_than_from_tensor(self): + """Test that empty() is faster than from tensor""" + shape = (1000, 1000) + + # Time empty creation + start = time.time() + for _ in range(10): + shared = SharedTensor.empty(shape, torch.float32) + shared.cleanup() + empty_time = time.time() - start + + # Time from_tensor creation + start = time.time() + for _ in range(10): + tensor = torch.randn(shape) + shared = SharedTensor(tensor=tensor) + shared.cleanup() + from_tensor_time = time.time() - start + + # empty() should be faster (no data copying) + assert empty_time < from_tensor_time + + def test_handle_serialization_fast(self): + """Test that handle serialization is fast""" + shared = SharedTensor.empty((10000, 10000), torch.float32) + handle = shared.get_handle() + + start = time.time() + for _ in range(1000): + pickled = pickle.dumps(handle) + unpickled = pickle.loads(pickled) + elapsed = time.time() - start + + # Should be able to do 1000 round trips in < 0.1 seconds + assert elapsed < 0.1 + + shared.cleanup() + + +class TestSharedTensorBfloat16: + """Specific tests for bfloat16 support""" + + def test_bfloat16_creation(self): + """Test bfloat16 tensor creation""" + shared = SharedTensor.empty((100, 100), torch.bfloat16) + assert shared.dtype == torch.bfloat16 + shared.cleanup() + + def test_bfloat16_from_tensor(self): + """Test creating shared tensor from bfloat16 tensor""" + original = torch.randn(50, 50, dtype=torch.bfloat16) + shared = SharedTensor(tensor=original) + + assert shared.dtype == torch.bfloat16 + assert torch.allclose(shared.tensor.float(), original.float(), rtol=1e-3) + + shared.cleanup() + + def 
test_bfloat16_handle_preservation(self):
        """Test bfloat16 dtype preserved through handle"""
        shared1 = SharedTensor.empty((20, 20), torch.bfloat16)
        shared1.tensor.normal_(0, 1)

        handle = shared1.get_handle()
        shared2 = SharedTensor(handle=handle)

        assert shared2.dtype == torch.bfloat16
        assert torch.allclose(shared1.tensor.float(), shared2.tensor.float(), rtol=1e-3)

        shared1.cleanup()

    def test_bfloat16_operations(self):
        """Test operations on bfloat16 tensors"""
        shared = SharedTensor.empty((100, 100), torch.bfloat16)
        tensor = shared.tensor

        tensor.normal_(0, 1)
        mean = tensor.float().mean().item()

        # Mean should be close to 0
        assert abs(mean) < 0.1

        shared.cleanup()


if __name__ == "__main__":
    pytest.main([__file__, "-v", "--tb=short"])

From 68395f9499d75f3375d0616d4488f02b5a3bf821 Mon Sep 17 00:00:00 2001
From: Yuxuan Hu
Date: Tue, 14 Oct 2025 12:42:21 -0700
Subject: [PATCH 02/30] refactor, fix test

---
 src/forge/util/_shared_tensor.py            | 64 +++++++++++----------
 tests/unit_tests/util/test_shared_tensor.py | 44 +++++++-------
 2 files changed, 58 insertions(+), 50 deletions(-)

diff --git a/src/forge/util/_shared_tensor.py b/src/forge/util/_shared_tensor.py
index 94d4d5350..b86cce249 100644
--- a/src/forge/util/_shared_tensor.py
+++ b/src/forge/util/_shared_tensor.py
@@ -16,7 +16,7 @@
 class SharedTensor:
     """Wrapper class for tensors backed by shared memory"""
 
-    def __init__(self, tensor=None, handle=None):
+    def __init__(self, *, tensor=None, handle=None):
         if tensor is not None:
             self._create_from_tensor(tensor)
         elif handle is not None:
@@ -87,65 +87,65 @@ def ones(
     def _create_empty(self, shape, dtype):
         """Initialize with empty tensor in shared memory"""
         # Store metadata
-        self.shape = tuple(shape) if not isinstance(shape, tuple) else shape
-        self.dtype = dtype
-        self.dtype_str = str(dtype)
+        self._shape = tuple(shape) if not isinstance(shape, tuple) else shape
+        self._dtype = dtype
+        self._dtype_str = str(dtype)
 
         # Calculate size
         element_size = torch.tensor([], dtype=dtype).element_size()
-        total_elements = int(np.prod(self.shape))
+        total_elements = int(np.prod(self._shape))
         byte_size = total_elements * element_size
 
         # Create shared memory (uninitialized - fast!)
         shm_name = f"shared_tensor_{uuid.uuid4().hex}"
-        self.shm = shared_memory.SharedMemory(
+        self._shm = shared_memory.SharedMemory(
             create=True, size=byte_size, name=shm_name
         )
-        self.shm_name = shm_name
+        self._shm_name = shm_name
 
     def _create_from_tensor(self, tensor):
         """Initialize from an existing tensor"""
         tensor = tensor.contiguous()
 
         # Store metadata
-        self.shape = tuple(tensor.shape)
-        self.dtype = tensor.dtype
-        self.dtype_str = str(tensor.dtype)
+        self._shape = tuple(tensor.shape)
+        self._dtype = tensor.dtype
+        self._dtype_str = str(tensor.dtype)
 
         # Create shared memory
         byte_size = tensor.numel() * tensor.element_size()
         shm_name = f"shared_tensor_{uuid.uuid4().hex}"
 
-        self.shm = shared_memory.SharedMemory(
+        self._shm = shared_memory.SharedMemory(
             create=True, size=byte_size, name=shm_name
         )
-        self.shm_name = shm_name
+        self._shm_name = shm_name
 
         # Copy data as raw bytes
         raw_bytes = tensor.view(torch.uint8).view(-1).cpu().contiguous().numpy()
-        self.shm.buf[:byte_size] = raw_bytes
+        self._shm.buf[:byte_size] = raw_bytes
 
     def _create_from_handle(self, handle):
         """Initialize from a handle"""
-        self.shm_name = handle["shm_name"]
-        self.shape = handle["shape"]
-        self.dtype_str = handle["dtype"]
-        self.dtype = self._parse_dtype(self.dtype_str)
+        self._shm_name = handle["shm_name"]
+        self._shape = handle["shape"]
+        self._dtype_str = handle["dtype"]
+        self._dtype = self._parse_dtype(self._dtype_str)
 
         # Attach to existing shared memory
-        self.shm = shared_memory.SharedMemory(name=self.shm_name)
+        self._shm = shared_memory.SharedMemory(name=self._shm_name)
 
     def _create_tensor_view(self):
         """Create tensor view of shared memory."""
-        element_size = torch.tensor([], dtype=self.dtype).element_size()
-        total_elements = int(np.prod(self.shape))
+        element_size = torch.tensor([], dtype=self._dtype).element_size()
+        total_elements = int(np.prod(self._shape))
         byte_size = total_elements * element_size
 
         # Create numpy array that shares the buffer
-        np_array = np.ndarray(shape=(byte_size,), dtype=np.uint8, buffer=self.shm.buf)
+        np_array = np.ndarray(shape=(byte_size,), dtype=np.uint8, buffer=self._shm.buf)
         # Create torch tensor from numpy (shares memory)
         uint8_tensor = torch.from_numpy(np_array)
-        tensor = uint8_tensor.view(self.dtype).reshape(self.shape)
+        tensor = uint8_tensor.view(self._dtype).reshape(self._shape)
 
         # Keep both the np array and the SharedTensor object alive
         tensor._forge_shared_tensor = self
@@ -160,7 +160,11 @@ def _parse_dtype(self, dtype_str):
 
     def get_handle(self):
         """Get picklable handle"""
-        return {"shm_name": self.shm_name, "shape": self.shape, "dtype": self.dtype_str}
+        return {
+            "shm_name": self._shm_name,
+            "shape": self._shape,
+            "dtype": self._dtype_str,
+        }
 
     @functools.cached_property
     def tensor(self):
@@ -175,22 +179,22 @@ def copy_from(self, source_tensor):
         Args:
             source_tensor: Source tensor to copy from
         """
-        if source_tensor.shape != self.shape:
-            raise ValueError(f"Shape mismatch: {source_tensor.shape} vs {self.shape}")
+        if source_tensor.shape != self._shape:
+            raise ValueError(f"Shape mismatch: {source_tensor.shape} vs {self._shape}")
         # Copy data
         self.tensor.copy_(source_tensor)
 
     def clone(self):
         """Create a new SharedTensor with copied data"""
-        new_shared = SharedTensor.empty(self.shape, self.dtype)
+        new_shared = SharedTensor.empty(self._shape, self._dtype)
         new_shared.tensor.copy_(self.tensor)
         return new_shared
 
     def cleanup(self):
         """Clean up shared memory"""
         try:
-            self.shm.close()
-            self.shm.unlink()
+            self._shm.close()
+            self._shm.unlink()
         except Exception:
             pass
 
@@ -198,9 +202,9 @@ def __del__(self):
         """Cleanup on deletion"""
-        if hasattr(self, "shm"):
+        if hasattr(self, "_shm"):
             try:
-                self.shm.close()
+                self._shm.close()
             except Exception:
                 pass
 
     def __repr__(self):
-        return f"SharedTensor(shape={self.shape}, dtype={self.dtype}, shm_name={self.shm_name})"
+        return f"SharedTensor(shape={self._shape}, dtype={self._dtype}, shm_name={self._shm_name})"
diff --git a/tests/unit_tests/util/test_shared_tensor.py b/tests/unit_tests/util/test_shared_tensor.py
index ebc61c9bd..0383366dd 100644
--- a/tests/unit_tests/util/test_shared_tensor.py
+++ b/tests/unit_tests/util/test_shared_tensor.py
@@ -25,8 +25,8 @@ def test_empty_creation(self):
 
         shared = SharedTensor.empty(shape, dtype)
 
-        assert shared.shape == shape
-        assert shared.dtype == dtype
+        assert shared.tensor.shape == torch.Size(shape)
+        assert shared.tensor.dtype == dtype
         assert shared.tensor.shape == torch.Size(shape)
         assert shared.tensor.dtype == dtype
 
@@ -37,7 +37,7 @@ def test_empty_with_bfloat16(self):
         shape = (50, 50)
         shared = SharedTensor.empty(shape, torch.bfloat16)
 
-        assert shared.dtype == torch.bfloat16
+        assert shared.tensor.dtype == torch.bfloat16
         assert shared.tensor.dtype == torch.bfloat16
 
         shared.cleanup()
@@ -69,8 +69,8 @@ def test_from_tensor_creation(self):
         original = torch.randn(50, 50)
         shared = SharedTensor(tensor=original)
 
-        assert shared.shape == tuple(original.shape)
-        assert shared.dtype == original.dtype
+        assert shared.tensor.shape == original.shape
+        assert shared.tensor.dtype == original.dtype
         assert torch.allclose(shared.tensor, original)
 
         shared.cleanup()
@@ -88,8 +88,8 @@ def test_from_handle_creation(self):
         reconstructed = SharedTensor(handle=handle)
 
         assert torch.all(reconstructed.tensor == 5.0)
-        assert reconstructed.shape == original.shape
-        assert reconstructed.dtype == original.dtype
+        assert reconstructed.tensor.shape == original.tensor.shape
+        assert reconstructed.tensor.dtype == original.tensor.dtype
 
         original.cleanup()
 
@@ -110,7 +110,7 @@ def test_creation_requires_argument(self):
     def test_various_shapes(self, shape):
         """Test creation with various shapes"""
         shared = SharedTensor.empty(shape, torch.float32)
-        assert shared.shape == shape
+        assert shared.tensor.shape == torch.Size(shape)
         assert shared.tensor.shape == torch.Size(shape)
         shared.cleanup()
 
@@ -138,7 +138,7 @@ def test_all_dtypes(self, dtype):
         shape = (10, 10)
         shared = SharedTensor.empty(shape, dtype)
 
-        assert shared.dtype == dtype
+        assert shared.tensor.dtype == dtype
         assert shared.tensor.dtype == dtype
 
         # Test that we can write to it
@@ -157,10 +157,8 @@ def test_dtype_conversion_in_handle(self):
             shared1 = SharedTensor.empty((5, 5), dtype)
             handle = shared1.get_handle()
 
-            assert handle["dtype"] == str(dtype)
-
             shared2 = SharedTensor(handle=handle)
-            assert shared2.dtype == dtype
+            assert shared2.tensor.dtype == dtype
 
             shared1.cleanup()
 
@@ -295,7 +293,7 @@ class TestSharedTensorMemory:
     def test_cleanup(self):
         """Test cleanup removes shared memory"""
         shared = SharedTensor.empty((10, 10), torch.float32)
-        shm_name = shared.shm_name
+        shm_name = shared._shm_name
 
         # Verify shared memory exists
         tensor = shared.tensor
@@ -350,9 +348,15 @@ class TestSharedTensorEdgeCases:
 
     def test_empty_shape(self):
         """Test scalar tensor (empty shape)"""
-        shared = SharedTensor.empty((), torch.float32)
-        assert shared.shape == ()
+        shared = SharedTensor.ones((), torch.float32)
+        assert shared.tensor.shape == ()
         assert shared.tensor.numel() == 1
+        assert torch.allclose(
+            shared.tensor,
+            torch.ones(
+                (),
+            ),
+        )
         shared.cleanup()
 
     def test_single_element_tensor(self):
@@ -368,7 +372,7 @@ def test_large_tensor(self): shape = (250_000_000,) shared = SharedTensor.empty(shape, torch.float32) - assert shared.shape == shape + assert shared.tensor.shape == shape assert shared.tensor.numel() == 250_000_000 shared.cleanup() @@ -395,7 +399,7 @@ def test_repr(self): assert "SharedTensor" in repr_str assert "10, 20" in repr_str assert "float32" in repr_str - assert shared.shm_name in repr_str + assert shared._shm_name in repr_str shared.cleanup() @@ -536,7 +540,7 @@ class TestSharedTensorBfloat16: def test_bfloat16_creation(self): """Test bfloat16 tensor creation""" shared = SharedTensor.empty((100, 100), torch.bfloat16) - assert shared.dtype == torch.bfloat16 + assert shared.tensor.dtype == torch.bfloat16 shared.cleanup() def test_bfloat16_from_tensor(self): @@ -544,7 +548,7 @@ def test_bfloat16_from_tensor(self): original = torch.randn(50, 50, dtype=torch.bfloat16) shared = SharedTensor(tensor=original) - assert shared.dtype == torch.bfloat16 + assert shared.tensor.dtype == torch.bfloat16 assert torch.allclose(shared.tensor.float(), original.float(), rtol=1e-3) shared.cleanup() @@ -557,7 +561,7 @@ def test_bfloat16_handle_preservation(self): handle = shared1.get_handle() shared2 = SharedTensor(handle=handle) - assert shared2.dtype == torch.bfloat16 + assert shared2.tensor.dtype == torch.bfloat16 assert torch.allclose(shared1.tensor.float(), shared2.tensor.float(), rtol=1e-3) shared1.cleanup() From 3789d49ebc2dda5b1d2cd22d04b7c0d48baf5640 Mon Sep 17 00:00:00 2001 From: Allen Wang <9057208+allenwang28@users.noreply.github.com> Date: Tue, 14 Oct 2025 12:45:04 -0700 Subject: [PATCH 03/30] set up titan distributed through monarch utils, colocate policy with workers --- src/forge/actors/policy.py | 25 +++++++++++++++---------- src/forge/actors/trainer.py | 15 +-------------- src/forge/controller/provisioner.py | 10 +++++++++- 3 files changed, 25 insertions(+), 25 deletions(-) diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py index 8e8c8de17..447372b39 100644 --- a/src/forge/actors/policy.py +++ b/src/forge/actors/policy.py @@ -49,7 +49,12 @@ load_tensor_from_dcp, ) -from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh +from forge.controller import ( + ForgeActor, + get_proc_mesh, + host_mesh_from_proc, + stop_proc_mesh, +) from forge.data_models.completion import Completion from forge.data_models.prompt import to_prompt from forge.env import TORCHSTORE_USE_RDMA @@ -143,17 +148,16 @@ async def launch( # pyright: ignore[reportIncompatibleMethodOverride] ) worker_procs = await get_proc_mesh(process_config=process_config) - # TODO - issues/144 we will want to ensure colocation with workers - # We're currently locating the Policy on the local host proc mesh - # vLLM initialization without setting env variables at proc_mesh creation - # level leads to issues. - # Once we can create multiple proc meshes on a host mesh, we can ensure - # host colocation + # Grab a single host from the workers... 
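+        # The singleton slice below narrows every mesh dimension to index 0,
+        # i.e. one host that is already running workers, so the policy actor
+        # is truly colocated with them (see issues/144).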
+        host_mesh = await host_mesh_from_proc(worker_procs)
+        singleton_slice = {k: slice(0, 1) for k in host_mesh.extent.keys()}
+        host_mesh = host_mesh.slice(**singleton_slice)
         policy_proc_config = copy(process_config)
         policy_proc_config.procs = 1
-        policy_proc_config.hosts = None
         policy_proc_config.with_gpus = False
-        policy_proc = await get_proc_mesh(process_config=policy_proc_config)
+        policy_proc = await get_proc_mesh(
+            process_config=policy_proc_config, host_mesh=host_mesh
+        )
 
         if isinstance(engine_args, Mapping):
             engine_args = EngineArgs(**engine_args)
@@ -343,7 +347,8 @@ def _preprocess_add_request(
         self, request: EngineCoreRequest
     ) -> tuple[Request, int]:
         """(forge/issues/332) Will require attention when we bump vllm versions
-        https://github.com/vllm-project/vllm/blob/0e3bb543f064eb416bca4f6f3013efa3830b12f7/vllm/v1/engine/core.py#L419"""
+        https://github.com/vllm-project/vllm/blob/0e3bb543f064eb416bca4f6f3013efa3830b12f7/vllm/v1/engine/core.py#L419
+        """
         if request.mm_hashes is not None:
             raise NotImplementedError("Support for mm_hash is not implemented yet.")
         req = Request.from_engine_core_request(request)
diff --git a/src/forge/actors/trainer.py b/src/forge/actors/trainer.py
index dd85b3c82..23ab870eb 100644
--- a/src/forge/actors/trainer.py
+++ b/src/forge/actors/trainer.py
@@ -5,7 +5,6 @@
 # LICENSE file in the root directory of this source tree.
 
 import logging
-import math
 import os
 import shutil
 
@@ -18,7 +17,7 @@
 import torch.distributed.checkpoint as dcp
 import torchstore as ts
 
-from monarch.actor import current_rank, current_size, endpoint
+from monarch.actor import endpoint
 from torch import Tensor
 from torch.distributed.checkpoint._nested_dict import flatten_state_dict
 from torchtitan.config.job_config import (
@@ -163,19 +162,7 @@ def __post_init__(self):
         self.step = 1  # fragile contract.
         self.num_training_steps = self.training.steps
         self.gradient_accumulation_steps = 1
-        self.rank = current_rank().rank
-        self.size = math.prod(current_size().values())
-
-        env = {
-            "RANK": str(self.rank),
-            "LOCAL_RANK": str(self.rank),
-            "LOCAL_WORLD_SIZE": str(self.size),
-            "GROUP_RANK": str(self.size),
-            "GROUP_WORLD_SIZE": str(self.size),
-            "ROLE_RANK": str(self.rank),
-            "ROLE_WORLD_SIZE": str(self.size),
-            "ROLE_NAME": "rank",
-            "WORLD_SIZE": str(self.size),
+        env = {
             "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
         }
         os.environ.update(env)
diff --git a/src/forge/controller/provisioner.py b/src/forge/controller/provisioner.py
index 1bb340328..ace380aab 100644
--- a/src/forge/controller/provisioner.py
+++ b/src/forge/controller/provisioner.py
@@ -18,8 +18,9 @@
 
 from monarch.tools import commands
 
-from forge.controller.launcher import BaseLauncher, get_launcher
+from monarch.utils import setup_env_for_distributed
 
+from forge.controller.launcher import BaseLauncher, get_launcher
 from forge.env import all_env_vars, FORGE_DISABLE_METRICS, MONARCH_HOSTMESH_V1
 from forge.types import ProcessConfig, ProvisionerConfig
 
@@ -294,6 +295,13 @@ def bootstrap(env: dict[str, str]):
             bootstrap=functools.partial(bootstrap, env=env_vars),
         )
 
+        # Set up environment variables for PyTorch distributed...
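+        # (Presumably this monarch helper exports MASTER_ADDR/MASTER_PORT and
+        # the per-rank variables on each proc; it replaces the hand-rolled env
+        # wiring deleted from trainer.py in this same patch.)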
+        await setup_env_for_distributed(
+            procs,
+            master_addr=addr,
+            master_port=port,
+        )
+
         if is_remote:
             await self.launcher.remote_setup(procs)

From 23c8fcbb2245e5ac5a5c8e5ccce9239cf852456b Mon Sep 17 00:00:00 2001
From: Yuxuan Hu
Date: Tue, 14 Oct 2025 12:52:58 -0700
Subject: [PATCH 04/30] add SharedTensorHandle class

---
 src/forge/util/_shared_tensor.py            | 35 ++++++++++++++-------
 tests/unit_tests/util/test_shared_tensor.py |  3 --
 2 files changed, 24 insertions(+), 14 deletions(-)

diff --git a/src/forge/util/_shared_tensor.py b/src/forge/util/_shared_tensor.py
index b86cce249..4d1eb089e 100644
--- a/src/forge/util/_shared_tensor.py
+++ b/src/forge/util/_shared_tensor.py
@@ -6,17 +6,30 @@
 
 import functools
 import uuid
+from dataclasses import dataclass
 from multiprocessing import shared_memory
-from typing import Tuple, Union
+from typing import Optional, Tuple, Union
 
 import numpy as np
 import torch
 
 
+@dataclass
+class SharedTensorHandle:
+    shm_name: str
+    shape: Tuple[int, ...]
+    dtype: str
+
+
 class SharedTensor:
     """Wrapper class for tensors backed by shared memory"""
 
-    def __init__(self, *, tensor=None, handle=None):
+    def __init__(
+        self,
+        *,
+        tensor: Optional[torch.Tensor] = None,
+        handle: Optional[SharedTensorHandle] = None,
+    ):
         if tensor is not None:
             self._create_from_tensor(tensor)
         elif handle is not None:
@@ -125,11 +138,11 @@ def _create_from_tensor(self, tensor):
         raw_bytes = tensor.view(torch.uint8).view(-1).cpu().contiguous().numpy()
         self._shm.buf[:byte_size] = raw_bytes
 
-    def _create_from_handle(self, handle):
+    def _create_from_handle(self, handle: SharedTensorHandle):
         """Initialize from a handle"""
-        self._shm_name = handle["shm_name"]
-        self._shape = handle["shape"]
-        self._dtype_str = handle["dtype"]
+        self._shm_name = handle.shm_name
+        self._shape = handle.shape
+        self._dtype_str = handle.dtype
         self._dtype = self._parse_dtype(self._dtype_str)
 
         # Attach to existing shared memory
@@ -160,11 +173,11 @@ def _parse_dtype(self, dtype_str):
 
     def get_handle(self):
         """Get picklable handle"""
-        return {
-            "shm_name": self._shm_name,
-            "shape": self._shape,
-            "dtype": self._dtype_str,
-        }
+        return SharedTensorHandle(
+            shm_name=self._shm_name,
+            shape=self._shape,
+            dtype=self._dtype_str,
+        )
 
     @functools.cached_property
     def tensor(self):
diff --git a/tests/unit_tests/util/test_shared_tensor.py b/tests/unit_tests/util/test_shared_tensor.py
index 0383366dd..12f8006e5 100644
--- a/tests/unit_tests/util/test_shared_tensor.py
+++ b/tests/unit_tests/util/test_shared_tensor.py
@@ -248,9 +248,6 @@ def test_handle_is_picklable(self):
         unpickled_handle = pickle.loads(pickled)
 
         assert unpickled_handle == handle
-        assert unpickled_handle["shm_name"] == handle["shm_name"]
-        assert unpickled_handle["shape"] == handle["shape"]
-        assert unpickled_handle["dtype"] == handle["dtype"]
 
         shared.cleanup()

From 15d27844328b7708fd3c139b953c20a02005c83f Mon Sep 17 00:00:00 2001
From: Yuxuan Hu
Date: Tue, 14 Oct 2025 13:28:05 -0700
Subject: [PATCH 05/30] shared memory weight loading

---
 src/forge/actors/policy.py | 61 ++++++++++++++++++++++++++++++++++++--
 1 file changed, 58 insertions(+), 3 deletions(-)

diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py
index 8e8c8de17..66eafbc80 100644
--- a/src/forge/actors/policy.py
+++ b/src/forge/actors/policy.py
@@ -13,6 +13,7 @@
 from collections.abc import Mapping
 from copy import copy
 from dataclasses import dataclass, field
+from typing import Optional
 
 import torch
 import torchstore as ts
@@ -57,6 +58,7 @@
 from forge.observability.metrics import
record_metric, Reduce from forge.observability.perf_tracker import Tracer from forge.types import ProcessConfig +from forge.util._shared_tensor import SharedTensor, SharedTensorHandle logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) @@ -343,7 +345,8 @@ def _preprocess_add_request( self, request: EngineCoreRequest ) -> tuple[Request, int]: """(forge/issues/332) Will require attention when we bump vllm versions - https://github.com/vllm-project/vllm/blob/0e3bb543f064eb416bca4f6f3013efa3830b12f7/vllm/v1/engine/core.py#L419""" + https://github.com/vllm-project/vllm/blob/0e3bb543f064eb416bca4f6f3013efa3830b12f7/vllm/v1/engine/core.py#L419 + """ if request.mm_hashes is not None: raise NotImplementedError("Support for mm_hash is not implemented yet.") req = Request.from_engine_core_request(request) @@ -399,6 +402,11 @@ async def update_weights(self, policy_version: int) -> None: >>> await trainer.push_weights() >>> policy.update_weights(version) """ + # Create a new thread because fetch is alloc heavy and blocking on CPU + # (however alloc doesn't acquire the GIL) + fetch_task = asyncio.create_task( + asyncio.to_thread(self._fetch_weights, policy_version) + ) # Serialize updates (only one update at a time) async with self.update_lock: # Grab the lock to stop accepting requests and wait on pending requests @@ -429,7 +437,13 @@ async def update_weights(self, policy_version: int) -> None: logger.debug(f"Starting weight update on {self.__class__.__name__}") # Call update_weights on every policy_worker - await self.policy_worker.update_weights.call(policy_version) + t = Tracer("policy_perf/_waiting_for_fetch_weights") + t.start() + fetched_weights = await fetch_task + t.end() + await self.policy_worker.update_weights.call( + shared_memory_state_dict=fetched_weigths + ) self.policy_version = policy_version # After updating the weights, we need to reset the KV cache @@ -593,9 +607,50 @@ async def setup_kv_cache(self) -> KVCacheConfig: async def execute_model(self, schedule: SchedulerOutput) -> ModelRunnerOutput: return self.worker.execute_model(schedule) + async def _fetch_weights( + self, policy_version: int + ) -> dict[str, SharedTensorHandle]: + """Fetch weights from torchstore and return a dict of {name: SharedTensorHandle}.""" + t = Tracer("policy_perf/_fetch_weights") + t.start() + prefix = get_param_prefix(policy_version) + matching_keys = await ts.keys(prefix) + hf_param_names = [extract_param_name(key) for key in matching_keys] + # We can't pass a generator since vllm load_weights is not async. + # Instead, we just call load_weights with one parameter at a time. + shared_memory_state_dict = {} + for name in hf_param_names: + param_key = get_param_key(policy_version, name) + param = await ts.get(param_key) + # TODO: preallocate in the shared memory once we have plumbing in torchstore. 
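+                # One blocking copy per parameter: the fetched tensor bytes
+                # land in a fresh shared-memory segment, and only the small
+                # picklable handle is passed on to the workers.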
+ shared_memory_state_dict[name] = SharedTensor(tensor=param).get_handle() + t.stop() + return shared_memory_state_dict + @endpoint - async def update_weights(self, policy_version: int) -> None: + async def update_weights( + self, + policy_version: Optional[int], + *, + shared_memory_state_dict: Optional[dict[str, SharedTensorHandle]] = None, + ) -> None: model = self.worker.model_runner.model + if shared_memory_state_dict is not None: + logger.info("[PolicyWorker] update weights from shared memory.") + t = Tracer( + "policy_worker_perf/update_weights_from_shared_memory", timer="gpu" + ) + t.start() + for name, param_handle in shared_memory_state_dict.items(): + param = SharedTensor(handle=param_handle).tensor + loaded = model.load_weights([(name, param)]) + loaded_weights = set(loaded) + t.stop() + return + if policy_version is None: + raise ValueError( + "policy_version must be provided if not using shared_memory_state_dict" + ) prefix = get_param_prefix(policy_version) matching_keys = await ts.keys(prefix) dcp_whole_state_dict_key = get_dcp_whole_state_dict_key(policy_version) From c7ca7385205db1c6f399034533b89d34af0d465e Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Tue, 14 Oct 2025 13:43:02 -0700 Subject: [PATCH 06/30] oopsie --- src/forge/actors/policy.py | 40 +++++++++++++++++++------------------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py index 66eafbc80..62b13ff81 100644 --- a/src/forge/actors/policy.py +++ b/src/forge/actors/policy.py @@ -388,6 +388,26 @@ async def run(self) -> None: if len(self.requests) == 0: self.request_lock.notify_all() + async def _fetch_weights( + self, policy_version: int + ) -> dict[str, SharedTensorHandle]: + """Fetch weights from torchstore and return a dict of {name: SharedTensorHandle}.""" + t = Tracer("policy_perf/_fetch_weights") + t.start() + prefix = get_param_prefix(policy_version) + matching_keys = await ts.keys(prefix) + hf_param_names = [extract_param_name(key) for key in matching_keys] + # We can't pass a generator since vllm load_weights is not async. + # Instead, we just call load_weights with one parameter at a time. + shared_memory_state_dict = {} + for name in hf_param_names: + param_key = get_param_key(policy_version, name) + param = await ts.get(param_key) + # TODO: preallocate in the shared memory once we have plumbing in torchstore. + shared_memory_state_dict[name] = SharedTensor(tensor=param).get_handle() + t.stop() + return shared_memory_state_dict + @endpoint async def update_weights(self, policy_version: int) -> None: """Update weights on base model from a policy version to be found in a torchstore volume. @@ -607,26 +627,6 @@ async def setup_kv_cache(self) -> KVCacheConfig: async def execute_model(self, schedule: SchedulerOutput) -> ModelRunnerOutput: return self.worker.execute_model(schedule) - async def _fetch_weights( - self, policy_version: int - ) -> dict[str, SharedTensorHandle]: - """Fetch weights from torchstore and return a dict of {name: SharedTensorHandle}.""" - t = Tracer("policy_perf/_fetch_weights") - t.start() - prefix = get_param_prefix(policy_version) - matching_keys = await ts.keys(prefix) - hf_param_names = [extract_param_name(key) for key in matching_keys] - # We can't pass a generator since vllm load_weights is not async. - # Instead, we just call load_weights with one parameter at a time. 
- shared_memory_state_dict = {} - for name in hf_param_names: - param_key = get_param_key(policy_version, name) - param = await ts.get(param_key) - # TODO: preallocate in the shared memory once we have plumbing in torchstore. - shared_memory_state_dict[name] = SharedTensor(tensor=param).get_handle() - t.stop() - return shared_memory_state_dict - @endpoint async def update_weights( self, From cbd529e5823599722960a2374725b3dacf181f53 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Tue, 14 Oct 2025 13:50:42 -0700 Subject: [PATCH 07/30] end -> stop --- src/forge/actors/policy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py index 62b13ff81..aaf4b400c 100644 --- a/src/forge/actors/policy.py +++ b/src/forge/actors/policy.py @@ -460,7 +460,7 @@ async def update_weights(self, policy_version: int) -> None: t = Tracer("policy_perf/_waiting_for_fetch_weights") t.start() fetched_weights = await fetch_task - t.end() + t.stop() await self.policy_worker.update_weights.call( shared_memory_state_dict=fetched_weigths ) From 20e162fb2dda47f4251739406476f267591e0d76 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Tue, 14 Oct 2025 13:55:04 -0700 Subject: [PATCH 08/30] typo --- src/forge/actors/policy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py index aaf4b400c..bcc672a50 100644 --- a/src/forge/actors/policy.py +++ b/src/forge/actors/policy.py @@ -462,7 +462,7 @@ async def update_weights(self, policy_version: int) -> None: fetched_weights = await fetch_task t.stop() await self.policy_worker.update_weights.call( - shared_memory_state_dict=fetched_weigths + shared_memory_state_dict=fetched_weights ) self.policy_version = policy_version From eadf3a5f08ccaeb3e74aa70745192cc1ac64f890 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Tue, 14 Oct 2025 13:58:38 -0700 Subject: [PATCH 09/30] make policy_version optional --- src/forge/actors/policy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py index bcc672a50..74878673a 100644 --- a/src/forge/actors/policy.py +++ b/src/forge/actors/policy.py @@ -630,7 +630,7 @@ async def execute_model(self, schedule: SchedulerOutput) -> ModelRunnerOutput: @endpoint async def update_weights( self, - policy_version: Optional[int], + policy_version: Optional[int] = None, *, shared_memory_state_dict: Optional[dict[str, SharedTensorHandle]] = None, ) -> None: From 2b675f88b999dfddf117ea4b10d05b28c7e86bb0 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Tue, 14 Oct 2025 14:12:18 -0700 Subject: [PATCH 10/30] fix --- src/forge/actors/policy.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py index 74878673a..aad0a1fa6 100644 --- a/src/forge/actors/policy.py +++ b/src/forge/actors/policy.py @@ -422,11 +422,9 @@ async def update_weights(self, policy_version: int) -> None: >>> await trainer.push_weights() >>> policy.update_weights(version) """ - # Create a new thread because fetch is alloc heavy and blocking on CPU - # (however alloc doesn't acquire the GIL) - fetch_task = asyncio.create_task( - asyncio.to_thread(self._fetch_weights, policy_version) - ) + # TODO: currently the alloc in ts.get will block the event loop unfortunately + # potentially we need to change torchstore + fetch_task = asyncio.create_task(self._fetch_weights(policy_version)) # Serialize updates (only one update at a time) 
async with self.update_lock: # Grab the lock to stop accepting requests and wait on pending requests From 8b488b700eee95d78228c793f9db905501d3b77c Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Tue, 14 Oct 2025 14:20:50 -0700 Subject: [PATCH 11/30] no leak --- src/forge/actors/policy.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py index aad0a1fa6..2f2b8b948 100644 --- a/src/forge/actors/policy.py +++ b/src/forge/actors/policy.py @@ -463,6 +463,8 @@ async def update_weights(self, policy_version: int) -> None: shared_memory_state_dict=fetched_weights ) self.policy_version = policy_version + for _, handle in fetched_weights.items(): + SharedTensor(handle=handle).cleanup() # After updating the weights, we need to reset the KV cache self.scheduler.reset_prefix_cache() From 5e7528cfb78c544a31f471c20ab5e7145e83fb93 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Tue, 14 Oct 2025 21:22:48 +0000 Subject: [PATCH 12/30] disable dcp in 8b --- apps/grpo/qwen3_8b.yaml | 1 - 1 file changed, 1 deletion(-) diff --git a/apps/grpo/qwen3_8b.yaml b/apps/grpo/qwen3_8b.yaml index 9b2f70edd..534e5b92a 100644 --- a/apps/grpo/qwen3_8b.yaml +++ b/apps/grpo/qwen3_8b.yaml @@ -41,7 +41,6 @@ policy: # Trainer configuration trainer: - use_dcp: true model: name: qwen3 flavor: 8B From 2b23e500fbb6e06d8ff6e6abe270fda84e0d337b Mon Sep 17 00:00:00 2001 From: Allen Wang <9057208+allenwang28@users.noreply.github.com> Date: Tue, 14 Oct 2025 15:16:05 -0700 Subject: [PATCH 13/30] undo the colocation --- src/forge/actors/policy.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py index 447372b39..036691f2a 100644 --- a/src/forge/actors/policy.py +++ b/src/forge/actors/policy.py @@ -155,9 +155,14 @@ async def launch( # pyright: ignore[reportIncompatibleMethodOverride] policy_proc_config = copy(process_config) policy_proc_config.procs = 1 policy_proc_config.with_gpus = False - policy_proc = await get_proc_mesh( - process_config=policy_proc_config, host_mesh=host_mesh - ) + + # TODO - not working yet, delete this once debugged + policy_proc_config.hosts = None + policy_proc = await get_proc_mesh(process_config=policy_proc_config) + + # policy_proc = await get_proc_mesh( + # process_config=policy_proc_config, host_mesh=host_mesh + # ) if isinstance(engine_args, Mapping): engine_args = EngineArgs(**engine_args) From 4b850ad4f9f7eee76ad5211714c4f84ba479f59a Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Tue, 14 Oct 2025 14:29:46 -0700 Subject: [PATCH 14/30] debug info --- src/forge/actors/policy.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py index 2f2b8b948..6a947a83d 100644 --- a/src/forge/actors/policy.py +++ b/src/forge/actors/policy.py @@ -641,10 +641,12 @@ async def update_weights( "policy_worker_perf/update_weights_from_shared_memory", timer="gpu" ) t.start() + loaded_weights = set() for name, param_handle in shared_memory_state_dict.items(): param = SharedTensor(handle=param_handle).tensor loaded = model.load_weights([(name, param)]) - loaded_weights = set(loaded) + loaded_weights.update(loaded) + logger.info(f"[PolicyWorker] update {len(loaded_weights)} paremeters") t.stop() return if policy_version is None: From f7cbcb47d5acea36e5b60267b9d32893e02363e8 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Tue, 14 Oct 2025 15:17:10 -0700 Subject: [PATCH 15/30] typo --- src/forge/actors/policy.py | 2 +- 1 file 
changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py
index 6a947a83d..b902c2e1a 100644
--- a/src/forge/actors/policy.py
+++ b/src/forge/actors/policy.py
@@ -646,7 +646,7 @@ async def update_weights(
             param = SharedTensor(handle=param_handle).tensor
             loaded = model.load_weights([(name, param)])
             loaded_weights.update(loaded)
-        logger.info(f"[PolicyWorker] update {len(loaded_weights)} paremeters")
+        logger.info(f"[PolicyWorker] updated {len(loaded_weights)} parameters")
         t.stop()
         return
         if policy_version is None:

From 09836f3a865ca70c8e8339eadbb3e4c4fe59b49d Mon Sep 17 00:00:00 2001
From: Yuxuan Hu
Date: Tue, 14 Oct 2025 15:34:03 -0700
Subject: [PATCH 16/30] refactor

---
 src/forge/actors/policy.py                  |   4 +-
 src/forge/util/_shared_tensor.py            |  25 ++-
 tests/unit_tests/util/test_shared_tensor.py | 191 ++++++++++++++++----
 3 files changed, 176 insertions(+), 44 deletions(-)

diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py
index b902c2e1a..3c321e74b 100644
--- a/src/forge/actors/policy.py
+++ b/src/forge/actors/policy.py
@@ -464,7 +464,7 @@ async def update_weights(self, policy_version: int) -> None:
         )
         self.policy_version = policy_version
         for _, handle in fetched_weights.items():
-            SharedTensor(handle=handle).cleanup()
+            handle.to_shared_tensor().drop()
 
         # After updating the weights, we need to reset the KV cache
         self.scheduler.reset_prefix_cache()
@@ -643,7 +643,7 @@ async def update_weights(
         t.start()
         loaded_weights = set()
         for name, param_handle in shared_memory_state_dict.items():
-            param = SharedTensor(handle=param_handle).tensor
+            param = param_handle.to_shared_tensor().tensor
             loaded = model.load_weights([(name, param)])
             loaded_weights.update(loaded)
         logger.info(f"[PolicyWorker] updated {len(loaded_weights)} parameters")
diff --git a/src/forge/util/_shared_tensor.py b/src/forge/util/_shared_tensor.py
index 4d1eb089e..50290281e 100644
--- a/src/forge/util/_shared_tensor.py
+++ b/src/forge/util/_shared_tensor.py
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+from __future__ import annotations
+
 import functools
 import uuid
 from dataclasses import dataclass
@@ -20,6 +22,15 @@ class SharedTensorHandle:
     shape: Tuple[int, ...]
     dtype: str
 
+    def to_shared_tensor(self) -> SharedTensor:
+        """
+        Create a SharedTensor from this handle.
+
+        Returns:
+            SharedTensor instance attached to the shared memory referenced by this handle
+        """
+        return SharedTensor(handle=self)
+
 
 class SharedTensor:
     """Wrapper class for tensors backed by shared memory"""
@@ -203,8 +214,18 @@ def clone(self):
         new_shared.tensor.copy_(self.tensor)
         return new_shared
 
-    def cleanup(self):
-        """Clean up shared memory"""
+    def drop(self):
+        """
+        Release and unlink the shared memory.
+
+        This method closes the shared memory handle and removes the shared memory
+        segment from the system. After calling this method, the shared memory
+        will no longer be accessible by any process.
+
+        Note:
+            This should be called when the shared tensor is no longer needed.
+            Failing to call this method may result in shared memory leaks.
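+
+        Example (a minimal sketch of the producer side)::
+
+            shared = SharedTensor(tensor=torch.ones(4))
+            handle = shared.get_handle()  # picklable; cheap to send over RPC
+            # ... consumers attach via handle.to_shared_tensor() ...
+            shared.drop()  # producer unlinks once consumers are done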
+ """ try: self._shm.close() self._shm.unlink() diff --git a/tests/unit_tests/util/test_shared_tensor.py b/tests/unit_tests/util/test_shared_tensor.py index 12f8006e5..094567449 100644 --- a/tests/unit_tests/util/test_shared_tensor.py +++ b/tests/unit_tests/util/test_shared_tensor.py @@ -30,7 +30,7 @@ def test_empty_creation(self): assert shared.tensor.shape == torch.Size(shape) assert shared.tensor.dtype == dtype - shared.cleanup() + shared.drop() def test_empty_with_bfloat16(self): """Test creating empty bfloat16 tensor""" @@ -40,7 +40,7 @@ def test_empty_with_bfloat16(self): assert shared.tensor.dtype == torch.bfloat16 assert shared.tensor.dtype == torch.bfloat16 - shared.cleanup() + shared.drop() def test_zeros_creation(self): """Test creating zero-initialized tensor""" @@ -51,7 +51,7 @@ def test_zeros_creation(self): assert torch.all(tensor == 0) assert tensor.sum().item() == 0.0 - shared.cleanup() + shared.drop() def test_ones_creation(self): """Test creating ones-initialized tensor""" @@ -62,7 +62,7 @@ def test_ones_creation(self): assert torch.all(tensor == 1) assert tensor.sum().item() == 200.0 - shared.cleanup() + shared.drop() def test_from_tensor_creation(self): """Test creating from existing tensor""" @@ -73,7 +73,7 @@ def test_from_tensor_creation(self): assert shared.tensor.dtype == original.dtype assert torch.allclose(shared.tensor, original) - shared.cleanup() + shared.drop() def test_from_handle_creation(self): """Test creating from handle""" @@ -91,7 +91,7 @@ def test_from_handle_creation(self): assert reconstructed.tensor.shape == original.tensor.shape assert reconstructed.tensor.dtype == original.tensor.dtype - original.cleanup() + original.drop() def test_creation_requires_argument(self): """Test that creation without arguments raises error""" @@ -112,7 +112,7 @@ def test_various_shapes(self, shape): shared = SharedTensor.empty(shape, torch.float32) assert shared.tensor.shape == torch.Size(shape) assert shared.tensor.shape == torch.Size(shape) - shared.cleanup() + shared.drop() class TestSharedTensorDtypes: @@ -149,7 +149,7 @@ def test_all_dtypes(self, dtype): else: shared.tensor.fill_(3.14) - shared.cleanup() + shared.drop() def test_dtype_conversion_in_handle(self): """Test dtype is preserved through handle""" @@ -160,7 +160,7 @@ def test_dtype_conversion_in_handle(self): shared2 = SharedTensor(handle=handle) assert shared2.tensor.dtype == dtype - shared1.cleanup() + shared1.drop() class TestSharedTensorOperations: @@ -174,7 +174,7 @@ def test_copy_from(self): shared.copy_from(source) assert torch.allclose(shared.tensor, source) - shared.cleanup() + shared.drop() def test_copy_from_shape_mismatch(self): """Test copy_from raises error on shape mismatch""" @@ -184,7 +184,7 @@ def test_copy_from_shape_mismatch(self): with pytest.raises(ValueError, match="Shape mismatch"): shared.copy_from(source) - shared.cleanup() + shared.drop() def test_clone(self): """Test cloning creates independent copy""" @@ -201,8 +201,8 @@ def test_clone(self): assert torch.all(cloned.tensor == 5.0) assert torch.all(original.tensor == 10.0) - original.cleanup() - cloned.cleanup() + original.drop() + cloned.drop() def test_tensor_modifications(self): """Test that modifications to tensor are reflected""" @@ -217,7 +217,7 @@ def test_tensor_modifications(self): assert tensor2[0, 0].item() == 99.0 assert torch.all(tensor2[5:, :] == 42.0) - shared.cleanup() + shared.drop() def test_inplace_operations(self): """Test in-place operations work""" @@ -232,7 +232,7 @@ def test_inplace_operations(self): 
assert abs(new_mean - (mean + 5.0)) < 0.1 - shared.cleanup() + shared.drop() class TestSharedTensorSerialization: @@ -249,7 +249,7 @@ def test_handle_is_picklable(self): assert unpickled_handle == handle - shared.cleanup() + shared.drop() def test_handle_small_size(self): """Test that handle is small (efficient for RPC)""" @@ -261,7 +261,7 @@ def test_handle_small_size(self): # Handle should be < 1KB even for huge tensors assert len(pickled) < 1024 - shared.cleanup() + shared.drop() def test_data_integrity_after_pickle(self): """Test data is preserved through handle pickling""" @@ -281,14 +281,14 @@ def test_data_integrity_after_pickle(self): # Verify data is same assert torch.allclose(shared2.tensor.float(), original_data.float(), rtol=1e-3) - shared1.cleanup() + shared1.drop() class TestSharedTensorMemory: """Test memory management and cleanup""" - def test_cleanup(self): - """Test cleanup removes shared memory""" + def test_drop(self): + """Test drop removes shared memory""" shared = SharedTensor.empty((10, 10), torch.float32) shm_name = shared._shm_name @@ -296,8 +296,8 @@ def test_cleanup(self): tensor = shared.tensor tensor.fill_(5.0) - # Cleanup - shared.cleanup() + # Drop shared memory + shared.drop() # Trying to attach should fail from multiprocessing import shared_memory @@ -321,7 +321,7 @@ def test_multiple_views_same_memory(self): # Verify tensor1 sees the change assert torch.all(tensor1 == 10.0) - shared.cleanup() + shared.drop() def test_handle_reconstruction_shares_memory(self): """Test that handle reconstruction shares same memory""" @@ -337,7 +337,7 @@ def test_handle_reconstruction_shares_memory(self): # Verify shared1 sees the change assert torch.all(shared1.tensor == 14.0) - shared1.cleanup() + shared1.drop() class TestSharedTensorEdgeCases: @@ -354,14 +354,14 @@ def test_empty_shape(self): (), ), ) - shared.cleanup() + shared.drop() def test_single_element_tensor(self): """Test 1-element tensor""" shared = SharedTensor.empty((1,), torch.float32) shared.tensor.fill_(42.0) assert shared.tensor.item() == 42.0 - shared.cleanup() + shared.drop() def test_large_tensor(self): """Test large tensor (1GB)""" @@ -372,7 +372,7 @@ def test_large_tensor(self): assert shared.tensor.shape == shape assert shared.tensor.numel() == 250_000_000 - shared.cleanup() + shared.drop() def test_non_contiguous_tensor_conversion(self): """Test that non-contiguous tensors are handled""" @@ -386,7 +386,7 @@ def test_non_contiguous_tensor_conversion(self): # Result should match assert torch.allclose(shared.tensor, original) - shared.cleanup() + shared.drop() def test_repr(self): """Test string representation""" @@ -398,7 +398,7 @@ def test_repr(self): assert "float32" in repr_str assert shared._shm_name in repr_str - shared.cleanup() + shared.drop() class TestSharedTensorMultiprocess: @@ -429,7 +429,7 @@ def reader_process(handle_dict, result_queue): assert abs(result - expected) < 1e-5 - shared.cleanup() + shared.drop() def test_multiprocess_write(self): """Test writing to shared tensor from another process""" @@ -452,7 +452,7 @@ def writer_process(handle_dict, value): # Verify in main process assert torch.all(shared.tensor == 42.0) - shared.cleanup() + shared.drop() def test_multiprocess_bidirectional(self): """Test bidirectional communication""" @@ -485,8 +485,8 @@ def worker_process(input_handle, output_handle): output_shared.tensor, expected ), "output: {}, expected: {}".format(output_shared.tensor, expected) - input_shared.cleanup() - output_shared.cleanup() + input_shared.drop() + 
output_shared.drop() class TestSharedTensorPerformance: @@ -500,7 +500,7 @@ def test_empty_faster_than_from_tensor(self): start = time.time() for _ in range(10): shared = SharedTensor.empty(shape, torch.float32) - shared.cleanup() + shared.drop() empty_time = time.time() - start # Time from_tensor creation @@ -508,7 +508,7 @@ def test_empty_faster_than_from_tensor(self): for _ in range(10): tensor = torch.randn(shape) shared = SharedTensor(tensor=tensor) - shared.cleanup() + shared.drop() from_tensor_time = time.time() - start # empty() should be faster (no data copying) @@ -528,7 +528,118 @@ def test_handle_serialization_fast(self): # Should be able to do 1000 round trips in < 0.1 seconds assert elapsed < 0.1 - shared.cleanup() + shared.drop() + + +class TestSharedTensorHandleToSharedTensor: + """Test SharedTensorHandle.to_shared_tensor() method""" + + def test_to_shared_tensor_basic(self): + """Test basic creation of SharedTensor from handle using to_shared_tensor method""" + original = SharedTensor.empty((10, 10), torch.float32) + original.tensor.fill_(7.0) + + handle = original.get_handle() + reconstructed = handle.to_shared_tensor() + + assert torch.all(reconstructed.tensor == 7.0) + assert reconstructed.tensor.shape == original.tensor.shape + assert reconstructed.tensor.dtype == original.tensor.dtype + + original.drop() + + def test_to_shared_tensor_preserves_data(self): + """Test that to_shared_tensor preserves original data""" + original = SharedTensor.empty((20, 30), torch.float32) + original.tensor.normal_(0, 1) + original_data = original.tensor.clone() + + handle = original.get_handle() + reconstructed = handle.to_shared_tensor() + + assert torch.allclose(reconstructed.tensor, original_data) + + original.drop() + + def test_to_shared_tensor_shares_memory(self): + """Test that to_shared_tensor shares memory with original""" + original = SharedTensor.empty((15, 15), torch.float32) + original.tensor.fill_(5.0) + + handle = original.get_handle() + reconstructed = handle.to_shared_tensor() + + reconstructed.tensor.fill_(10.0) + + assert torch.all(original.tensor == 10.0) + + original.drop() + + def test_to_shared_tensor_with_various_dtypes(self): + """Test to_shared_tensor works with different data types""" + for dtype in [torch.float32, torch.float64, torch.bfloat16, torch.int32]: + original = SharedTensor.empty((5, 5), dtype) + if ( + dtype == torch.bfloat16 + or dtype == torch.float32 + or dtype == torch.float64 + ): + original.tensor.normal_(0, 1) + else: + original.tensor.fill_(42) + + handle = original.get_handle() + reconstructed = handle.to_shared_tensor() + + assert reconstructed.tensor.dtype == dtype + if dtype == torch.bfloat16: + assert torch.allclose( + reconstructed.tensor.float(), original.tensor.float(), rtol=1e-3 + ) + else: + assert torch.allclose(reconstructed.tensor, original.tensor) + + original.drop() + + def test_to_shared_tensor_multiprocess(self): + """Test to_shared_tensor in multiprocess scenario""" + + def worker_process(handle, result_queue): + shared = handle.to_shared_tensor() + result_queue.put(shared.tensor.sum().item()) + + original = SharedTensor.empty((50, 50), torch.float32) + original.tensor.fill_(3.0) + + handle = original.get_handle() + result_queue = Queue() + + p = Process(target=worker_process, args=(handle, result_queue)) + p.start() + p.join() + + result = result_queue.get() + expected = 3.0 * 50 * 50 + + assert abs(result - expected) < 1e-5 + + original.drop() + + def test_to_shared_tensor_equivalent_to_constructor(self): + """Test that 
handle.to_shared_tensor() is equivalent to SharedTensor(handle=handle)""" + original = SharedTensor.empty((25, 25), torch.float32) + original.tensor.normal_(0, 1) + + handle = original.get_handle() + + via_method = handle.to_shared_tensor() + via_constructor = SharedTensor(handle=handle) + + assert torch.allclose(via_method.tensor, via_constructor.tensor) + assert via_method.tensor.shape == via_constructor.tensor.shape + assert via_method.tensor.dtype == via_constructor.tensor.dtype + + original.drop() class TestSharedTensorBfloat16: @@ -538,7 +649,7 @@ def test_bfloat16_creation(self): """Test bfloat16 tensor creation""" shared = SharedTensor.empty((100, 100), torch.bfloat16) assert shared.tensor.dtype == torch.bfloat16 - shared.cleanup() + shared.drop() def test_bfloat16_from_tensor(self): """Test creating shared tensor from bfloat16 tensor""" @@ -548,7 +659,7 @@ def test_bfloat16_from_tensor(self): assert shared.tensor.dtype == torch.bfloat16 assert torch.allclose(shared.tensor.float(), original.float(), rtol=1e-3) - shared.cleanup() + shared.drop() def test_bfloat16_handle_preservation(self): """Test bfloat16 dtype preserved through handle""" @@ -561,7 +672,7 @@ def test_bfloat16_handle_preservation(self): assert shared2.tensor.dtype == torch.bfloat16 assert torch.allclose(shared1.tensor.float(), shared2.tensor.float(), rtol=1e-3) - shared1.cleanup() + shared1.drop() def test_bfloat16_operations(self): """Test operations on bfloat16 tensors""" @@ -574,7 +685,7 @@ def test_bfloat16_operations(self): # Mean should be close to 0 assert abs(mean) < 0.1 - shared.cleanup() + shared.drop() if __name__ == "__main__": From 9fa7395501031f6cc842e75c6ffcfd3c4b913cb9 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Tue, 14 Oct 2025 15:42:50 -0700 Subject: [PATCH 17/30] temp: reduce num_replicas to 2 --- apps/grpo/qwen3_32b.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/apps/grpo/qwen3_32b.yaml b/apps/grpo/qwen3_32b.yaml index e7a0cf509..e41ed4527 100644 --- a/apps/grpo/qwen3_32b.yaml +++ b/apps/grpo/qwen3_32b.yaml @@ -119,7 +119,7 @@ ref_model: services: policy: procs: ${policy.engine_args.tensor_parallel_size} - num_replicas: 4 + num_replicas: 2 hosts: 1 with_gpus: true mesh_name: policy From 7744f4cd08deb11c7ac18f3118b2147b6d7e1724 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Tue, 14 Oct 2025 16:17:40 -0700 Subject: [PATCH 18/30] fix bad merge --- src/forge/actors/generator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/forge/actors/generator.py b/src/forge/actors/generator.py index 8a50e0e54..73ef5e757 100644 --- a/src/forge/actors/generator.py +++ b/src/forge/actors/generator.py @@ -459,10 +459,10 @@ async def update_weights(self, version: int) -> None: t.start() fetched_weights = await fetch_task t.stop() - await self.worker.update_weights.call( + await self.generator_worker.update_weights.call( shared_memory_state_dict=fetched_weights ) - self.version = version + self.generator_version = version for _, handle in fetched_weights.items(): handle.to_shared_tensor().drop() From 8970ff42af038e3b291155992f486de46434478f Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Tue, 14 Oct 2025 16:18:47 -0700 Subject: [PATCH 19/30] revert to 4 --- apps/grpo/qwen3_32b.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/apps/grpo/qwen3_32b.yaml b/apps/grpo/qwen3_32b.yaml index e41ed4527..e7a0cf509 100644 --- a/apps/grpo/qwen3_32b.yaml +++ b/apps/grpo/qwen3_32b.yaml @@ -119,7 +119,7 @@ ref_model: services: policy: procs: 
${policy.engine_args.tensor_parallel_size} - num_replicas: 2 + num_replicas: 4 hosts: 1 with_gpus: true mesh_name: policy From 48821de4cc7e9f56c8cb368e6d01b0097ff146c6 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Tue, 14 Oct 2025 17:15:21 -0700 Subject: [PATCH 20/30] move _fetch_weights to policy worker --- src/forge/actors/generator.py | 65 ++++++++++++++++++-------------- src/forge/util/_shared_tensor.py | 2 +- 2 files changed, 37 insertions(+), 30 deletions(-) diff --git a/src/forge/actors/generator.py b/src/forge/actors/generator.py index 73ef5e757..8e845b095 100644 --- a/src/forge/actors/generator.py +++ b/src/forge/actors/generator.py @@ -388,24 +388,6 @@ async def run(self) -> None: if len(self.requests) == 0: self.request_lock.notify_all() - async def _fetch_weights(self, version: int) -> dict[str, SharedTensorHandle]: - """Fetch weights from torchstore and return a dict of {name: SharedTensorHandle}.""" - t = Tracer("policy_perf/_fetch_weights") - t.start() - prefix = get_param_prefix(version) - matching_keys = await ts.keys(prefix) - hf_param_names = [extract_param_name(key) for key in matching_keys] - # We can't pass a generator since vllm load_weights is not async. - # Instead, we just call load_weights with one parameter at a time. - shared_memory_state_dict = {} - for name in hf_param_names: - param_key = get_param_key(version, name) - param = await ts.get(param_key) - # TODO: preallocate in the shared memory once we have plumbing in torchstore. - shared_memory_state_dict[name] = SharedTensor(tensor=param).get_handle() - t.stop() - return shared_memory_state_dict - @endpoint async def update_weights(self, version: int) -> None: """Update weights on base model from a generator version to be found in a torchstore volume. @@ -422,7 +404,11 @@ async def update_weights(self, version: int) -> None: """ # TODO: currently the alloc in ts.get will block the event loop unfortunately # potentially we need to change torchstore - fetch_task = asyncio.create_task(self._fetch_weights(version)) + # TODO: move this logic to Generator once we can make sure Generator and GeneratorWorker are on the same host. 
+ if not self.use_dcp: + fetch_task = asyncio.create_task( + self.policy_worker._fetch_weights.choose(version) + ) # Serialize updates (only one update at a time) async with self.update_lock: # Grab the lock to stop accepting requests and wait on pending requests @@ -454,17 +440,20 @@ async def update_weights( ) logger.debug(f"Starting weight update on {self.__class__.__name__}") - # Call update_weights on every policy_worker - t = Tracer("generator_perf/_waiting_for_fetch_weights") - t.start() - fetched_weights = await fetch_task - t.stop() - await self.generator_worker.update_weights.call( - shared_memory_state_dict=fetched_weights - ) + if not self.use_dcp: + t = Tracer("generator_perf/_waiting_for_fetch_weights") + t.start() + fetched_weights = await fetch_task + t.stop() + # Call update_weights on every policy_worker + await self.generator_worker.update_weights.call( + shared_memory_state_dict=fetched_weights + ) + for _, handle in fetched_weights.items(): + handle.to_shared_tensor().drop() + else: + await self.generator_worker.update_weights.call(version=version) self.generator_version = version - for _, handle in fetched_weights.items(): - handle.to_shared_tensor().drop() # After updating the weights, we need to reset the KV cache self.scheduler.reset_prefix_cache() @@ -680,6 +669,24 @@ async def update_weights( loaded_weights.update(loaded) t.stop() + async def _fetch_weights(self, version: int) -> dict[str, SharedTensorHandle]: + """Fetch weights from torchstore and return a dict of {name: SharedTensorHandle}.""" + t = Tracer("policy_perf/_fetch_weights") + t.start() + prefix = get_param_prefix(version) + matching_keys = await ts.keys(prefix) + hf_param_names = [extract_param_name(key) for key in matching_keys] + # We can't pass a generator since vllm load_weights is not async. + # Instead, we just call load_weights with one parameter at a time. + shared_memory_state_dict = {} + for name in hf_param_names: + param_key = get_param_key(version, name) + param = await ts.get(param_key) + # TODO: preallocate in the shared memory once we have plumbing in torchstore. + shared_memory_state_dict[name] = SharedTensor(tensor=param).get_handle() + t.stop() + return shared_memory_state_dict + @endpoint async def _test_save_model_params(self): """Save model parameters before weight update, used for testing purposes only.""" diff --git a/src/forge/util/_shared_tensor.py b/src/forge/util/_shared_tensor.py index 50290281e..7839fe2d0 100644 --- a/src/forge/util/_shared_tensor.py +++ b/src/forge/util/_shared_tensor.py @@ -156,7 +156,7 @@ def _create_from_handle(self, handle: SharedTensorHandle): self._dtype_str = handle.dtype self._dtype = self._parse_dtype(self._dtype_str) - # Attach to existing shared memory + # Attach to existing shared memory\ self._shm = shared_memory.SharedMemory(name=self._shm_name) def _create_tensor_view(self): From 2798f2e6b78b730e3cb59e3e19729d51a7924edf Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Tue, 14 Oct 2025 17:19:26 -0700 Subject: [PATCH 21/30] typo --- src/forge/actors/generator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/forge/actors/generator.py b/src/forge/actors/generator.py index 8e845b095..c85c824ce 100644 --- a/src/forge/actors/generator.py +++ b/src/forge/actors/generator.py @@ -407,7 +407,7 @@ async def update_weights(self, version: int) -> None: # TODO: move this logic to Generator once we can make sure Generator and GeneratorWorker are on the same host.
if not self.use_dcp: fetch_task = asyncio.create_task( - self.policy_worker._fetch_weights.choose(version) + self.generator_worker._fetch_weights.choose(version) ) # Serialize updates (only one update at a time) async with self.update_lock: From ddf8d26a40d0a2f4d1a52777f0392a50cf8ea57c Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Tue, 14 Oct 2025 17:24:42 -0700 Subject: [PATCH 22/30] endpoint --- src/forge/actors/generator.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/forge/actors/generator.py b/src/forge/actors/generator.py index c85c824ce..59a728f2f 100644 --- a/src/forge/actors/generator.py +++ b/src/forge/actors/generator.py @@ -669,6 +669,7 @@ async def update_weights( loaded_weights.update(loaded) t.stop() + @endpoint async def _fetch_weights(self, version: int) -> dict[str, SharedTensorHandle]: """Fetch weights from torchstore and return a dict of {name: SharedTensorHandle}.""" t = Tracer("policy_perf/_fetch_weights") From 71b89c19c77e6362604fcd9e99ae313d47b4cedc Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Tue, 14 Oct 2025 17:28:21 -0700 Subject: [PATCH 23/30] clean up --- src/forge/actors/generator.py | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/src/forge/actors/generator.py b/src/forge/actors/generator.py index 59a728f2f..b8c6be14d 100644 --- a/src/forge/actors/generator.py +++ b/src/forge/actors/generator.py @@ -627,7 +627,7 @@ async def update_weights( if shared_memory_state_dict is not None: logger.info("[PolicyWorker] update weights from shared memory.") t = Tracer( - "policy_worker_perf/update_weights_from_shared_memory", timer="gpu" + "generator_worker_perf/update_weights_from_shared_memory", timer="gpu" ) t.start() loaded_weights = set() @@ -642,11 +642,12 @@ async def update_weights( raise ValueError( "version must be provided if not using shared_memory_state_dict" ) + # If shared memory is not provided, we assume we are using DCP prefix = get_param_prefix(version) matching_keys = await ts.keys(prefix) dcp_whole_state_dict_key = get_dcp_whole_state_dict_key(version) loaded_weights = set() - t = Tracer("worker_perf/update_weights", timer="gpu") + t = Tracer("generator_worker_perf/update_weights", timer="gpu") t.start() # Entire state dict is stored in a single DCP handle if dcp_whole_state_dict_key in matching_keys: @@ -657,22 +658,14 @@ async def update_weights( loaded = model.load_weights([(name, param)]) del param loaded_weights.update(loaded) - else: # Load each parameter from torchstore directly without DCP - hf_param_names = [extract_param_name(key) for key in matching_keys] - # We can't pass a generator since vllm load_weights is not async. - # Instead, we just call load_weights with one parameter at a time. 
- for name in hf_param_names: - param_key = get_param_key(version, name) - param = await ts.get(param_key) - loaded = model.load_weights([(name, param)]) - del param - loaded_weights.update(loaded) + else: + raise RuntimeError("No DCP handle found for the given version") t.stop() @endpoint async def _fetch_weights(self, version: int) -> dict[str, SharedTensorHandle]: """Fetch weights from torchstore and return a dict of {name: SharedTensorHandle}.""" - t = Tracer("policy_perf/_fetch_weights") + t = Tracer("generator_worker_perf/_fetch_weights") t.start() prefix = get_param_prefix(version) matching_keys = await ts.keys(prefix) From 1971a4fd25c6c5157b11f8013beaf499adf2fed7 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Tue, 14 Oct 2025 17:31:42 -0700 Subject: [PATCH 24/30] fix --- src/forge/actors/generator.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/forge/actors/generator.py b/src/forge/actors/generator.py index b8c6be14d..430cfbcc7 100644 --- a/src/forge/actors/generator.py +++ b/src/forge/actors/generator.py @@ -406,8 +406,12 @@ async def update_weights(self, version: int) -> None: # potentially we need to change torchstore # TODO: move this logic to Generator once we can make sure Generator and GeneratorWorker are on the same host. if not self.use_dcp: + # We have to do this because Monarch future is not directly compatible with asyncio + async def fetch_coro(): + return await self.generator_worker._fetch_weights.choose(version) + fetch_task = asyncio.create_task( - self.generator_worker._fetch_weights.choose(version) + fetch_coro(), ) # Serialize updates (only one update at a time) async with self.update_lock: From dc301aa00933529d5c0c02833e0c4b72fdd32f1d Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Tue, 14 Oct 2025 18:01:36 -0700 Subject: [PATCH 25/30] fix --- src/forge/actors/generator.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/forge/actors/generator.py b/src/forge/actors/generator.py index 430cfbcc7..8ac4c8310 100644 --- a/src/forge/actors/generator.py +++ b/src/forge/actors/generator.py @@ -453,8 +453,9 @@ async def fetch_coro(): await self.generator_worker.update_weights.call( shared_memory_state_dict=fetched_weights ) - for _, handle in fetched_weights.items(): - handle.to_shared_tensor().drop() + # This cannot be dropped in the Generator because it is not on the same host as the GeneratorWorker + # TODO: move this logic to Generator once we can make sure Generator and GeneratorWorker are on the same host.
+ await self.generator_worker._drop_shared_memory.choose(fetched_weights) else: await self.generator_worker.update_weights.call(version=version) self.generator_version = version @@ -685,6 +686,14 @@ async def _fetch_weights(self, version: int) -> dict[str, SharedTensorHandle]: t.stop() return shared_memory_state_dict + @endpoint + async def _drop_shared_memory( + self, shared_memory_state_dict: dict[str, SharedTensorHandle] + ): + """Drop shared memory tensors after use.""" + for _, handle in shared_memory_state_dict.items(): + handle.to_shared_tensor().drop() + @endpoint async def _test_save_model_params(self): """Save model parameters before weight update, used for testing purposes only.""" From c87975302f22e66815cdc2c35c90799606f95144 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Tue, 14 Oct 2025 18:23:41 -0700 Subject: [PATCH 26/30] rearrange --- src/forge/actors/generator.py | 20 +++++++------------- 1 file changed, 7 insertions(+), 13 deletions(-) diff --git a/src/forge/actors/generator.py b/src/forge/actors/generator.py index 8ac4c8310..116034ffb 100644 --- a/src/forge/actors/generator.py +++ b/src/forge/actors/generator.py @@ -402,17 +402,6 @@ async def update_weights(self, version: int) -> None: >>> await trainer.push_weights() >>> generator.update_weights(version) """ - # TODO: currently the alloc in ts.get will block the event loop unfortunately - # potentially we need to change torchstore - # TODO: move this logic to Generator once we can make sure Generator and GeneratorWorker are on the same host. - if not self.use_dcp: - # We have to do this because Monarch future is not directly compatible with asyncio - async def fetch_coro(): - return await self.generator_worker._fetch_weights.choose(version) - - fetch_task = asyncio.create_task( - fetch_coro(), - ) # Serialize updates (only one update at a time) async with self.update_lock: # Grab the lock to stop accepting requests and wait on pending requests @@ -445,9 +434,14 @@ async def fetch_coro(): logger.debug(f"Starting weight update on {self.__class__.__name__}") if not self.use_dcp: - t = Tracer("generator_perf/_waiting_for_fetch_weights") + # TODO: currently the alloc in ts.get will block the event loop unfortunately + # potentially we need to change torchstore + # We have to do this because Monarch future is not directly compatible with asyncio + t = Tracer("generator_perf/_fetch_weights") t.start() - fetched_weights = await fetch_task + fetched_weights = await self.generator_worker._fetch_weights.choose( + version + ) t.stop() # Call update_weights on every policy_worker await self.generator_worker.update_weights.call( From c462911a949233a632ee03105332771d653efbae Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Tue, 14 Oct 2025 18:40:53 -0700 Subject: [PATCH 27/30] log --- src/forge/actors/generator.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/forge/actors/generator.py b/src/forge/actors/generator.py index 116034ffb..20b74847b 100644 --- a/src/forge/actors/generator.py +++ b/src/forge/actors/generator.py @@ -439,6 +439,9 @@ async def update_weights(self, version: int) -> None: # We have to do this because Monarch future is not directly compatible with asyncio t = Tracer("generator_perf/_fetch_weights") t.start() + logger.info( + f"[Generator] Fetching weights for v{version} to shared memory" + ) fetched_weights = await self.generator_worker._fetch_weights.choose( version ) From 571750f28d797b2f1c04b54197e8770a7ad6b56e Mon Sep 17 00:00:00 2001 From: Allen Wang <9057208+allenwang28@users.noreply.github.com>
Date: Wed, 15 Oct 2025 02:04:25 +0000 Subject: [PATCH 28/30] vllm colocation works --- src/forge/actors/policy.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py index 036691f2a..efed34ef1 100644 --- a/src/forge/actors/policy.py +++ b/src/forge/actors/policy.py @@ -156,13 +156,12 @@ async def launch( # pyright: ignore[reportIncompatibleMethodOverride] policy_proc_config.procs = 1 policy_proc_config.with_gpus = False - # TODO - not working yet, delete this once debugged - policy_proc_config.hosts = None - policy_proc = await get_proc_mesh(process_config=policy_proc_config) - - # policy_proc = await get_proc_mesh( - # process_config=policy_proc_config, host_mesh=host_mesh - # ) + # By passing in the host_mesh here, we will get a new proc + # spawned on the provided host_mesh. Since that host mesh is + # taken from the policy_proc, this ensures colocation. + policy_proc = await get_proc_mesh( + process_config=policy_proc_config, host_mesh=host_mesh + ) if isinstance(engine_args, Mapping): engine_args = EngineArgs(**engine_args) From a155a4c90f5415b381f94eab22111fad49d02d80 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Tue, 14 Oct 2025 22:31:28 -0700 Subject: [PATCH 29/30] caching allocation for weight updates --- src/forge/actors/generator.py | 67 ++++++++++++++++++++++++++++++----- 1 file changed, 58 insertions(+), 9 deletions(-) diff --git a/src/forge/actors/generator.py b/src/forge/actors/generator.py index 20b74847b..f4e3a972d 100644 --- a/src/forge/actors/generator.py +++ b/src/forge/actors/generator.py @@ -9,6 +9,7 @@ import asyncio import logging import os +import queue import sys from collections.abc import Mapping from copy import copy @@ -206,6 +207,11 @@ async def setup(self): self.request_lock = asyncio.Condition() # Guard for accepting_requests self.update_lock = asyncio.Condition() # Guard for updating requests + # Shared memory allocated for weight updates + self.cached_state_dict_allocs: queue.Queue[ + dict[str, SharedTensorHandle] + ] = queue.Queue(maxsize=2) + vllm_config: VllmConfig = self.engine_args.create_engine_config( UsageContext.LLM_CLASS ) @@ -442,8 +448,12 @@ async def update_weights(self, version: int) -> None: logger.info( f"[Generator] Fetching weights for v{version} to shared memory" ) + try: + pre_allocated = self.cached_state_dict_allocs.get_nowait() + except queue.Empty: + pre_allocated = None fetched_weights = await self.generator_worker._fetch_weights.choose( - version + version, pre_allocated=pre_allocated ) t.stop() # Call update_weights on every policy_worker await self.generator_worker.update_weights.call( shared_memory_state_dict=fetched_weights ) # This cannot be dropped in the Generator because it is not on the same host as the GeneratorWorker # TODO: move this logic to Generator once we can make sure Generator and GeneratorWorker are on the same host. - await self.generator_worker._drop_shared_memory.choose(fetched_weights) + try: + self.cached_state_dict_allocs.put_nowait(fetched_weights) + except queue.Full: + logger.info( + "Cached state dict alloc queue is full. Dropping allocated state dict."
+ ) + self.policy_worker._drop_shared_memory(fetched_weights) else: await self.generator_worker.update_weights.call(version=version) self.generator_version = version @@ -480,6 +496,18 @@ async def get_version(self) -> int: async def stop(self): self.running = False + @endpoint + async def _cleanup_shared_memory(self): + """Cleanup shared memory allocated for weight updates.""" + while not self.cached_state_dict_allocs.empty(): + try: + state_dict = self.cached_state_dict_allocs.get_nowait() + self.policy_worker._drop_shared_memory(state_dict) + except queue.Empty: + logger.info( + "Cached state dict alloc queue is empty. No state dict to drop." + ) + def _to_completions(self, request_output: RequestOutput) -> list[Completion]: """Convert a vLLM RequestOutput to a list of Completion objects.""" completions = [] @@ -526,6 +554,7 @@ async def shutdown( # pyright: ignore[reportIncompatibleMethodOverride] # TODO - may want to expand stop to gracefully respond to # ongoing requests. await actor.stop.call() + await actor._cleanup_shared_memory.call() await stop_proc_mesh(actor._worker_procs) await stop_proc_mesh(actor._generator_proc) @@ -665,7 +694,12 @@ async def update_weights( t.stop() @endpoint - async def _fetch_weights(self, version: int) -> dict[str, SharedTensorHandle]: + async def _fetch_weights( + self, + version: int, + *, + pre_allocated: Optional[dict[str, SharedTensorHandle]] = None, + ) -> dict[str, SharedTensorHandle]: """Fetch weights from torchstore and return a dict of {name: SharedTensorHandle}.""" t = Tracer("generator_worker_perf/_fetch_weights") t.start() @@ -675,11 +709,26 @@ async def _fetch_weights(self, version: int) -> dict[str, SharedTensorHandle]: # We can't pass a generator since vllm load_weights is not async. # Instead, we just call load_weights with one parameter at a time. shared_memory_state_dict = {} - for name in hf_param_names: - param_key = get_param_key(version, name) - param = await ts.get(param_key) - # TODO: preallocate in the shared memory once we have plumbing in torchstore. - shared_memory_state_dict[name] = SharedTensor(tensor=param).get_handle() + if pre_allocated is not None: + logger.info( + "[Generator] fetching weights from torchstore to shared memory. Using pre-allocated shared memory." + ) + shared_memory_state_dict = pre_allocated + assert set(shared_memory_state_dict.keys()) == set( + hf_param_names + ), "The pre_allocated dict must have the same keys as hf_param_names" + for name, handle in shared_memory_state_dict.items(): + param_key = get_param_key(version, name) + param = handle.to_shared_tensor().tensor + await ts.get(param_key, inplace_tensor=param) + else: + logger.info( + "[Generator] fetching weights from torchstore to shared memory."
+ ) + for name in hf_param_names: + param_key = get_param_key(version, name) + param = await ts.get(param_key) + shared_memory_state_dict[name] = SharedTensor(tensor=param).get_handle() t.stop() return shared_memory_state_dict @@ -688,7 +737,7 @@ async def _drop_shared_memory( self, shared_memory_state_dict: dict[str, SharedTensorHandle] ): """Drop shared memory tensors after use.""" - for _, handle in shared_memory_state_dict.items(): + for handle in shared_memory_state_dict.values(): handle.to_shared_tensor().drop() @endpoint From 9fa32973c44c66a44fad74fbafeeccc36dffc5e2 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Wed, 15 Oct 2025 00:17:53 -0700 Subject: [PATCH 30/30] test --- src/forge/actors/generator.py | 147 ++++++++++++++++------------------ 1 file changed, 69 insertions(+), 78 deletions(-) diff --git a/src/forge/actors/generator.py b/src/forge/actors/generator.py index f90d6e252..adbc6f3dc 100644 --- a/src/forge/actors/generator.py +++ b/src/forge/actors/generator.py @@ -155,15 +155,15 @@ async def launch( # pyright: ignore[reportIncompatibleMethodOverride] host_mesh = await host_mesh_from_proc(worker_procs) singleton_slice = {k: slice(0, 1) for k in host_mesh.extent.keys()} host_mesh = host_mesh.slice(**singleton_slice) - policy_proc_config = copy(process_config) - policy_proc_config.procs = 1 - policy_proc_config.with_gpus = False + generator_proc_config = copy(process_config) + generator_proc_config.procs = 1 + generator_proc_config.with_gpus = False # By passing in the host_mesh here, we will get a new proc # spawned on the provided host_mesh. Since that host mesh is # taken from the policy_proc, this ensures colocation. - policy_proc = await get_proc_mesh( - process_config=policy_proc_config, host_mesh=host_mesh + generator_proc = await get_proc_mesh( + process_config=generator_proc_config, host_mesh=host_mesh ) if isinstance(engine_args, Mapping): engine_args = EngineArgs(**engine_args) @@ -260,6 +260,59 @@ def _start_processing(self): if self._run_task is None or self._run_task.done(): self._run_task = asyncio.create_task(self.run()) + async def _drop_shared_memory(self, state_dict: dict[str, SharedTensorHandle]): + for handle in state_dict.values(): + handle.to_shared_tensor().drop() + + async def _cleanup_shared_memory(self): + """Cleanup shared memory allocated for weight updates.""" + while not self.cached_state_dict_allocs.empty(): + try: + state_dict = self.cached_state_dict_allocs.get_nowait() + await self._drop_shared_memory(state_dict) + except queue.Empty: + logger.info( + "Cached state dict alloc queue is empty. No state dict to drop." + ) + + async def _fetch_weights( + self, + version: int, + *, + pre_allocated: Optional[dict[str, SharedTensorHandle]] = None, + ) -> dict[str, SharedTensorHandle]: + """Fetch weights from torchstore and return a dict of {name: SharedTensorHandle}.""" + t = Tracer("generator_perf/_fetch_weights") + t.start() + prefix = get_param_prefix(version) + matching_keys = await ts.keys(prefix) + hf_param_names = [extract_param_name(key) for key in matching_keys] + # We can't pass a generator since vllm load_weights is not async. + # Instead, we just call load_weights with one parameter at a time. + shared_memory_state_dict = {} + if pre_allocated is not None: + logger.info( + "[Generator] fetching weights from torchstore to shared memory. Using pre-allocated shared memory."
+ ) + shared_memory_state_dict = pre_allocated + assert set(shared_memory_state_dict.keys()) == set( + hf_param_names + ), "The pre_allocated dict must have the same keys as hf_param_names" + for name, handle in shared_memory_state_dict.items(): + param_key = get_param_key(version, name) + param = handle.to_shared_tensor().tensor + await ts.get(param_key, inplace_tensor=param) + else: + logger.info( + "[Generator] fetching weights from torchstore to shared memory." + ) + for name in hf_param_names: + param_key = get_param_key(version, name) + param = await ts.get(param_key) + shared_memory_state_dict[name] = SharedTensor(tensor=param).get_handle() + t.stop() + return shared_memory_state_dict + @endpoint async def generate(self, prompt: str, *, priority: int = 0) -> list[Completion]: """Generate a response for the given prompt @@ -416,6 +469,14 @@ async def update_weights(self, version: int) -> None: >>> await trainer.push_weights() >>> generator.update_weights(version) """ + logger.info(f"[Generator] Fetching weights for v{version} to shared memory") + try: + pre_allocated = self.cached_state_dict_allocs.get_nowait() + except queue.Empty: + pre_allocated = None + fetch_task = asyncio.create_task( + self._fetch_weights(version, pre_allocated=pre_allocated) + ) # Serialize updates (only one update at a time) async with self.update_lock: # Grab the lock to stop accepting requests and wait on pending requests @@ -451,32 +512,21 @@ async def update_weights(self, version: int) -> None: # TODO: currently the alloc in ts.get will block the event loop unfortunately # potentially we need to change torchstore # We have to do this because Monarch future is not directly compatible with asyncio - t = Tracer("generator_perf/_fetch_weights") + t = Tracer("generator_perf/waiting_for_fetch_weights") t.start() - logger.info( - f"[Generator] Fetching weights for v{version} to shared memory" - ) - try: - pre_allocated = self.cached_state_dict_allocs.get_nowait() - except queue.Empty: - pre_allocated = None - fetched_weights = await self.generator_worker._fetch_weights.choose( - version, pre_allocated=pre_allocated - ) + fetched_weights = await fetch_task t.stop() # Call update_weights on every policy_worker await self.generator_worker.update_weights.call( shared_memory_state_dict=fetched_weights ) - # This cannot be dropped in the Generator because it is not on the same host as the GeneratorWorker - # TODO: move this logic to Generator once we can make sure Generator and GeneratorWorker are on the same host. try: self.cached_state_dict_allocs.put_nowait(fetched_weights) except queue.Full: logger.info( "Cached state dict alloc queue is full. Dropping allocated state dict." ) - self.policy_worker._drop_shared_memory(fetched_weights) + await self._drop_shared_memory(fetched_weights) else: await self.generator_worker.update_weights.call(version=version) self.generator_version = version @@ -504,18 +554,6 @@ async def get_version(self) -> int: async def stop(self): self.running = False - @endpoint - async def _cleanup_shared_memory(self): - """Cleanup shared memory allocated for weight updates.""" - while not self.cached_state_dict_allocs.empty(): - try: - state_dict = self.cached_state_dict_allocs.get_nowait() - self.policy_worker._drop_shared_memory(state_dict) - except queue.Empty: - logger.info( - "Cached state dict alloc queue is empty. No state dict to drop."
- ) - def _to_completions(self, request_output: RequestOutput) -> list[Completion]: """Convert a vLLM RequestOutput to a list of Completion objects.""" completions = [] @@ -701,53 +739,6 @@ async def update_weights( raise RuntimeError("No DCP handle found for the given version") t.stop() - @endpoint - async def _fetch_weights( - self, - version: int, - *, - pre_allocated: Optional[dict[str, SharedTensorHandle]] = None, - ) -> dict[str, SharedTensorHandle]: - """Fetch weights from torchstore and return a dict of {name: SharedTensorHandle}.""" - t = Tracer("generator_worker_perf/_fetch_weights") - t.start() - prefix = get_param_prefix(version) - matching_keys = await ts.keys(prefix) - hf_param_names = [extract_param_name(key) for key in matching_keys] - # We can't pass a generator since vllm load_weights is not async. - # Instead, we just call load_weights with one parameter at a time. - shared_memory_state_dict = {} - if pre_allocated is not None: - logger.info( - "[Generator] fetching weights from torchstore to shared memory. Using pre-allocated shared memory." - ) - shared_memory_state_dict = pre_allocated - assert set(shared_memory_state_dict.keys()) == set( - hf_param_names - ), "The pre_allocated dict must have the same keys as hf_param_names" - for name, handle in shared_memory_state_dict.items(): - param_key = get_param_key(version, name) - param = handle.to_shared_tensor().tensor - await ts.get(param_key, inplace_tensor=param) - else: - logger.info( - "[Generator] fetching weights from torchstore to shared memory." - ) - for name in hf_param_names: - param_key = get_param_key(version, name) - param = await ts.get(param_key) - shared_memory_state_dict[name] = SharedTensor(tensor=param).get_handle() - t.stop() - return shared_memory_state_dict - - @endpoint - async def _drop_shared_memory( - self, shared_memory_state_dict: dict[str, SharedTensorHandle] - ): - """Drop shared memory tensors after use.""" - for handle in shared_memory_state_dict.values(): - handle.to_shared_tensor().drop() - @endpoint async def _test_save_model_params(self): """Save model parameters before weight update, used for testing purposes only."""
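The series above converges on one lifecycle for shared-memory weight transfer: copy each parameter into shared memory once, pass only the small picklable SharedTensorHandle between processes, attach on the receiving side via to_shared_tensor(), and have exactly one owner call drop() on the segment once every user is done. The sketch below is a minimal illustration of that lifecycle, not code from any patch; it assumes forge.util._shared_tensor is importable from an installed forge checkout and uses only the API visible in the diffs (SharedTensor(tensor=...), get_handle(), SharedTensorHandle.to_shared_tensor(), .tensor, and drop()).

import torch
from multiprocessing import Process, Queue

# Module path taken from the diffs above; assumes forge is installed.
from forge.util._shared_tensor import SharedTensor


def consumer(handle, out_queue):
    # Attach to the existing shared-memory segment through the picklable
    # handle; no tensor data is copied, only name/shape/dtype travel.
    shared = handle.to_shared_tensor()
    out_queue.put(shared.tensor.sum().item())


if __name__ == "__main__":
    # Producer side: one copy of the weights into shared memory, then a
    # tiny handle that is cheap to send over RPC or a process boundary.
    shared = SharedTensor(tensor=torch.ones(64, 64))
    handle = shared.get_handle()

    q = Queue()
    p = Process(target=consumer, args=(handle, q))
    p.start()
    p.join()
    assert q.get() == 64 * 64  # both processes observed the same bytes

    # Exactly one owner unlinks the segment once all users are finished,
    # mirroring handle.to_shared_tensor().drop() in the patches above.
    shared.drop()

A handle-only owner, such as the Generator after patch 30, can release the same segment with handle.to_shared_tensor().drop(), which is why the caching queue in the final patches stores SharedTensorHandle objects rather than live SharedTensor instances.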