diff --git a/pyproject.toml b/pyproject.toml index 61cd9700c..1d0ff0cf1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,6 +42,7 @@ dev = [ "tomli>=1.1.0", "anyio", "pytest-asyncio", + "multiprocess", ] oss = [ "torch", diff --git a/src/forge/actors/generator.py b/src/forge/actors/generator.py index 0dc385cc0..41e2590a9 100644 --- a/src/forge/actors/generator.py +++ b/src/forge/actors/generator.py @@ -13,10 +13,12 @@ from collections.abc import Mapping from copy import copy from dataclasses import dataclass, field +from typing import Optional import torch import torchstore as ts -from monarch.actor import current_rank, endpoint, ProcMesh +from monarch.actor import current_rank, endpoint, ProcMesh, this_host + from vllm.config import VllmConfig from vllm.engine.arg_utils import EngineArgs @@ -60,6 +62,7 @@ from forge.observability.metrics import record_metric, Reduce from forge.observability.perf_tracker import Tracer from forge.types import ProcessConfig +from forge.util._shared_tensor import SharedTensor, SharedTensorHandle logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) @@ -92,6 +95,8 @@ class Generator(ForgeActor): engine_args: EngineArgs | Mapping = field(default_factory=EngineArgs) sampling_params: SamplingParams | Mapping = field(default_factory=SamplingParams) use_dcp_for_weight_sync: bool | None = None + prefetch_weights_to_shm: bool = True + n_fetcher_procs: int = 8 def __post_init__(self): super().__init__() @@ -226,11 +231,61 @@ async def setup(self): log_stats=None, ) self._start_processing() + if self.prefetch_weights_to_shm: + self._spawn_fetchers() + + def _spawn_fetchers(self): + """Spawn weight fetchers that prefetch weights from torchstore to shared memory.""" + # TODO: this assumes the generator is on the same host as the worker + # and only works for single host generators. Figure out how to support + # generators with workers spanned across multiple hosts. + fetcher_procs = this_host().spawn_procs( + per_host={"procs": self.n_fetcher_procs} + ) + self._fetcher_procs = fetcher_procs + self.weight_fetchers = fetcher_procs.spawn("weight_fetcher", _WeightFetcher) def _start_processing(self): if self._run_task is None or self._run_task.done(): self._run_task = asyncio.create_task(self.run()) + async def _drop_shared_memory(self, state_dict: dict[str, SharedTensorHandle]): + for handle in state_dict.values(): + handle.drop() + + async def _fetch_weights( + self, + version: int, + ) -> dict[str, SharedTensorHandle]: + """Fetch weights from torchstore and return a dict of {name: SharedTensorHandle}.""" + t = Tracer("generator_perf/_fetch_weights") + t.start() + prefix = get_param_prefix(version) + matching_keys = await ts.keys(prefix) + hf_param_names = [extract_param_name(key) for key in matching_keys] + + n_fetchers = self.weight_fetchers.size() + + def split_keys(keys): + return [keys[i::n_fetchers] for i in range(n_fetchers)] + + futures = [] + for i, names in enumerate(split_keys(hf_param_names)): + fut = self.weight_fetchers.slice(procs=i).fetch.call_one( + version=version, param_names=names + ) + futures.append(fut) + + sub_state_dicts = [await fut for fut in futures] + + state_dict = {} + for sd in sub_state_dicts: + state_dict.update(sd) + + t.stop() + + return state_dict + @endpoint async def generate(self, prompt: str, *, priority: int = 0) -> list[Completion]: """Generate a response for the given prompt @@ -384,6 +439,12 @@ async def update_weights(self, version: int) -> None: >>> await trainer.push_weights() >>> generator.update_weights(version) """ + # TODO: enable shared memory prefetch for DCP-based weight sync + if self.prefetch_weights_to_shm and not self.use_dcp_for_weight_sync: + logger.info(f"[Generator] Fetching weights for v{version} to shared memory") + fetch_fut = asyncio.create_task(self._fetch_weights(version)) + else: + fetch_fut = None # Serialize updates (only one update at a time) async with self.update_lock: # Grab the lock to stop accepting requests and wait on pending requests @@ -415,8 +476,19 @@ async def update_weights(self, version: int) -> None: ) logger.debug(f"Starting weight update on {self.__class__.__name__}") - # Call update_weights on every generator worker - await self.worker.update_weights.call(version=version) + + if fetch_fut is not None: + t = Tracer("generator_perf/waiting_for_fetch_weights") + t.start() + fetched_weights = await fetch_fut + t.stop() + # Call update_weights on every policy_worker + await self.worker.update_weights.call( + shared_memory_state_dict=fetched_weights + ) + await self._drop_shared_memory(fetched_weights) + else: + await self.worker.update_weights.call(version=version) self.generator_version = version # After updating the weights, we need to reset the KV cache @@ -490,6 +562,7 @@ async def shutdown( # pyright: ignore[reportIncompatibleMethodOverride] await actor.stop.call() await stop_proc_mesh(actor._worker_procs) await stop_proc_mesh(actor._generator_proc) + await stop_proc_mesh(actor._fetcher_procs) @endpoint async def save_model_params(self): @@ -569,14 +642,42 @@ async def execute_model(self, schedule: SchedulerOutput) -> ModelRunnerOutput: return self.worker.execute_model(schedule) @endpoint - async def update_weights(self, version: int) -> None: + async def update_weights( + self, + version: Optional[int] = None, + *, + shared_memory_state_dict: Optional[dict[str, SharedTensorHandle]] = None, + ) -> None: model = self.worker.model_runner.model + if shared_memory_state_dict is not None: + logger.info("[PolicyWorker] update weights from shared memory.") + t = Tracer( + "generator_worker_perf/update_weights_from_shared_memory", timer="gpu" + ) + t.start() + loaded_weights = set() + for name, param_handle in shared_memory_state_dict.items(): + # Use context manager for automatic cleanup + with param_handle.to_shared_tensor() as shared_tensor: + param = shared_tensor.tensor + loaded = model.load_weights([(name, param)]) + del param + loaded_weights.update(loaded) + logger.info(f"[PolicyWorker] updated {len(loaded_weights)} paremeters") + t.stop() + return + # normal update_weights without shared memory prefetching + if version is None: + raise ValueError( + "version must be provided if not using shared_memory_state_dict" + ) + logger.info("[PolicyWorker] update weights from torchstore.") prefix = get_param_prefix(version) matching_keys = await ts.keys(prefix) dcp_whole_state_dict_key = get_dcp_whole_state_dict_key(version) use_dcp_for_weight_sync = dcp_whole_state_dict_key in matching_keys loaded_weights = set() - t = Tracer("worker_perf/update_weights", timer="gpu") + t = Tracer("generator_worker_perf/update_weights_from_torchstore", timer="gpu") t.start() if use_dcp_for_weight_sync: @@ -617,3 +718,27 @@ async def validate_model_params(self, validate_fn): return validate_fn( self._debug_saved_params, self.worker.model_runner.model, logger ) + + +class _WeightFetcher(ForgeActor): + """Fetches weights from torchstore and loads them into shared memory. + This has to be colocated with the GeneratorWorker.""" + + @endpoint + async def fetch( + self, + *, + version: int, + param_names: list[str], + ) -> dict[str, SharedTensorHandle]: + """Fetch weights from torchstore and load them into shared memory.""" + sd = {} + for name in param_names: + param_key = get_param_key(version, name) + param = await ts.get(param_key) + # Use context manager to ensure cleanup after getting handle + with SharedTensor(tensor=param) as shared_tensor: + handle = shared_tensor.get_handle() + sd[name] = handle + del param # Explicitly free the tensor after copying to shared memory + return sd diff --git a/src/forge/util/_shared_tensor.py b/src/forge/util/_shared_tensor.py new file mode 100644 index 000000000..18a7d65e6 --- /dev/null +++ b/src/forge/util/_shared_tensor.py @@ -0,0 +1,440 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +import logging + +import uuid +from dataclasses import dataclass +from multiprocessing import shared_memory +from typing import Optional, Tuple, Union + +import numpy as np +import torch + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +@dataclass +class SharedTensorHandle: + shm_name: str + shape: Tuple[int, ...] + dtype: str + + def to_shared_tensor(self) -> SharedTensor: + """ + Create a SharedTensor from this handle. + + Returns: + SharedTensor instance attached to the shared memory referenced by this handle + """ + return SharedTensor(handle=self) + + def drop(self) -> None: + """ + Unlink the shared memory segment. + + This marks the shared memory for deletion. The actual memory will be freed + once all processes have closed their handles to it. + + Note: This only unlinks, it does not close any handles. Processes that have + opened this shared memory should call close() on their SharedTensor instances. + """ + try: + # Attach to the shared memory just to unlink it + shm = shared_memory.SharedMemory(name=self.shm_name) + shm.close() + shm.unlink() + except Exception: + pass + + +class SharedTensor: + """ + Wrapper class for tensors backed by shared memory. + + This class provides a way to share tensors between processes using POSIX shared memory. + It's designed for efficient inter-process tensor communication without copying data. + + Ownership and Lifecycle Model: + ------------------------------ + 1. **Creator process**: + - Creates SharedTensor with tensor data or empty + - Gets a handle via get_handle() to pass to other processes + - **MUST** call close() after getting handle to release its reference + - **SHOULD** call drop()/unlink() when all processes are done + + 2. **Receiver processes**: + - Receive SharedTensorHandle (via RPC, pickle, etc.) + - Create SharedTensor from handle: SharedTensor(handle=handle) + - Use the tensor: handle.to_shared_tensor().tensor + - **MUST** call close() when done using the tensor + + 3. **Cleanup**: + - close(): Closes this process's file descriptor/handle + - drop()/unlink(): Marks shared memory for deletion (call once, from any process) + - Actual memory is freed when all processes have closed AND unlink is called + + Memory Leak Prevention: + ---------------------- + - **DO NOT** rely on __del__ for cleanup! Python GC is unpredictable. + - **ALWAYS** explicitly call close() when done with a SharedTensor + - **ALWAYS** call drop() on handles when sharing is complete + - Use context manager (with statement) for automatic cleanup + - After close(), accessing .tensor will raise RuntimeError + - After close(), getting handle will raise RuntimeError + + Closed State Behavior: + --------------------- + - Once close() is called, the SharedTensor enters a closed state + - Accessing .tensor after close() raises RuntimeError + - Calling get_handle() after close() raises RuntimeError + - You can check the state with the .is_closed property + - close() and drop() are idempotent (safe to call multiple times) + + Important Warning: + ------------------ + If you hold a reference to the tensor BEFORE calling close(), that + reference becomes INVALID after close(): + t = shared.tensor # Get reference + shared.close() # Close SharedTensor - unmaps memory + t.sum() # SEGFAULT! The memory is now invalid + + After close(), the shared memory mapping is unmapped, so ALL references + to the tensor (including cached ones) point to invalid memory. Accessing + them will cause segmentation faults or undefined behavior. + + Always ensure you're done with the tensor before calling close(). + + Example Usage: + ------------- + # Creator process + tensor = torch.randn(100, 100) + shared = SharedTensor(tensor=tensor) + handle = shared.get_handle() + shared.close() # Close creator's reference + # ... send handle to other process via RPC ... + handle.drop() # Unlink after all receivers have it + + # Receiver process + # ... receive handle via RPC ... + shared = SharedTensor(handle=handle) + result = shared.tensor.sum() # Use the tensor + shared.close() # Close receiver's reference + + # Or use context manager (recommended) + with SharedTensor(handle=handle) as shared: + result = shared.tensor.sum() + # Automatically closed + """ + + def __init__( + self, + *, + tensor: Optional[torch.Tensor] = None, + handle: Optional[SharedTensorHandle] = None, + ): + if tensor is not None: + self._create_from_tensor(tensor) + elif handle is not None: + self._create_from_handle(handle) + else: + raise ValueError("Must provide either tensor or handle") + + @classmethod + def empty( + cls, + shape: Union[Tuple[int, ...], torch.Size], + dtype: torch.dtype = torch.float32, + ): + """ + Create an empty tensor directly in shared memory (no copy/allocation overhead) + + Args: + shape: Shape of the tensor + dtype: PyTorch dtype (supports bfloat16, float32, etc.) + + Returns: + SharedTensor instance with uninitialized data + """ + instance = cls.__new__(cls) + instance._create_empty(shape, dtype) + return instance + + @classmethod + def zeros( + cls, + shape: Union[Tuple[int, ...], torch.Size], + dtype: torch.dtype = torch.float32, + ): + """ + Create a zero-initialized tensor in shared memory + + Args: + shape: Shape of the tensor + dtype: PyTorch dtype + + Returns: + SharedTensor instance with zeros + """ + shared_tensor = cls.empty(shape, dtype) + shared_tensor.tensor.zero_() + return shared_tensor + + @classmethod + def ones( + cls, + shape: Union[Tuple[int, ...], torch.Size], + dtype: torch.dtype = torch.float32, + ): + """ + Create a ones-initialized tensor in shared memory + + Args: + shape: Shape of the tensor + dtype: PyTorch dtype + + Returns: + SharedTensor instance with ones + """ + shared_tensor = cls.empty(shape, dtype) + shared_tensor.tensor.fill_(1) + return shared_tensor + + def _create_empty(self, shape, dtype): + """Initialize with empty tensor in shared memory""" + # Initialize lifecycle state + self._closed = False + self._tensor_cache = None + + # Store metadata + self._shape = tuple(shape) if not isinstance(shape, tuple) else shape + self._dtype = dtype + self._dtype_str = str(dtype) + + # Calculate size + element_size = torch.tensor([], dtype=dtype).element_size() + total_elements = int(np.prod(self._shape)) + byte_size = total_elements * element_size + + # Create shared memory (uninitialized - fast!) + shm_name = f"shared_tensor_{uuid.uuid4().hex}" + self._shm = shared_memory.SharedMemory( + create=True, size=byte_size, name=shm_name + ) + self._shm_name = shm_name + + def _create_from_tensor(self, tensor): + """Initialize from an existing tensor""" + # Initialize lifecycle state + self._closed = False + self._tensor_cache = None + + tensor = tensor.contiguous() + + # Store metadata + self._shape = tuple(tensor.shape) + self._dtype = tensor.dtype + self._dtype_str = str(tensor.dtype) + + # Create shared memory + byte_size = tensor.numel() * tensor.element_size() + shm_name = f"shared_tensor_{uuid.uuid4().hex}" + + self._shm = shared_memory.SharedMemory( + create=True, size=byte_size, name=shm_name + ) + self._shm_name = shm_name + + # Copy data as raw bytes + raw_bytes = tensor.view(torch.uint8).view(-1).cpu().contiguous().numpy() + self._shm.buf[:byte_size] = raw_bytes + del raw_bytes # Explicitly free the intermediate numpy array + + def _create_from_handle(self, handle: SharedTensorHandle): + """Initialize from a handle""" + # Initialize lifecycle state + self._closed = False + self._tensor_cache = None + + self._shm_name = handle.shm_name + self._shape = handle.shape + self._dtype_str = handle.dtype + self._dtype = self._parse_dtype(self._dtype_str) + + # Attach to existing shared memory\ + self._shm = shared_memory.SharedMemory(name=self._shm_name) + + def _create_tensor_view(self): + """Create tensor view of shared memory.""" + element_size = torch.tensor([], dtype=self._dtype).element_size() + total_elements = int(np.prod(self._shape)) + byte_size = total_elements * element_size + + # Create numpy array that shares the buffer + np_array = np.ndarray(shape=(byte_size,), dtype=np.uint8, buffer=self._shm.buf) + # Create torch tensor from numpy (shares memory) + uint8_tensor = torch.from_numpy(np_array) + tensor = uint8_tensor.view(self._dtype).reshape(self._shape) + + # Keep the np array alive + tensor._forge_np_array = np_array + + return tensor + + def _parse_dtype(self, dtype_str): + """Parse dtype string""" + dtype_str = dtype_str.replace("torch.", "") + return getattr(torch, dtype_str) + + def get_handle(self): + """ + Get a picklable handle to share this SharedTensor with other processes. + + Returns: + SharedTensorHandle: A lightweight handle that can be pickled and sent to other processes + + Raises: + RuntimeError: If called after close() has been called + """ + if self._closed: + raise RuntimeError( + "Cannot get handle after close(). Get the handle before closing." + ) + return SharedTensorHandle( + shm_name=self._shm_name, + shape=self._shape, + dtype=self._dtype_str, + ) + + @property + def tensor(self): + """ + Get the underlying tensor. + + Returns: + torch.Tensor: View into the shared memory + + Raises: + RuntimeError: If accessed after close() has been called + """ + if self._closed: + raise RuntimeError( + "Cannot access tensor after close(). The SharedTensor has been closed." + ) + if self._tensor_cache is None: + self._tensor_cache = self._create_tensor_view() + return self._tensor_cache + + def copy_from(self, source_tensor): + """ + Copy data from another tensor into this shared tensor + Useful when you create empty tensor first, then fill it + + Args: + source_tensor: Source tensor to copy from + """ + if source_tensor.shape != self._shape: + raise ValueError(f"Shape mismatch: {source_tensor.shape} vs {self._shape}") + # Copy data + self.tensor.copy_(source_tensor) + + def clone(self): + """Create a new SharedTensor with copied data""" + new_shared = SharedTensor.empty(self._shape, self._dtype) + new_shared.tensor.copy_(self.tensor) + return new_shared + + def close(self): + """ + Close this process's handle to the shared memory. + + This should be called when this process is done using the shared memory. + The shared memory will persist until all processes have closed their handles + and someone calls unlink(). + + After calling close(), this SharedTensor object should not be used anymore. + Accessing the tensor property after close() will raise a RuntimeError. + + This method is idempotent - calling it multiple times is safe. + + Note: If you hold a reference to the tensor before calling close(), + that reference will remain valid, but new accesses via shared.tensor + will raise an error. + """ + if self._closed: + return # Already closed, nothing to do + + self._closed = True + self._tensor_cache = None # Release tensor and numpy array references + + try: + self._shm.close() + except Exception as e: + logger.error(f"Error closing shared memory {self._shm_name}: {e}") + + def drop(self): + """ + Close and unlink the shared memory. + + This method first closes this process's handle (if not already closed), + then marks the shared memory for deletion. The actual memory will be freed + once all processes have closed their handles. + + This method is idempotent - calling it multiple times is safe. + + Note: + This should be called when the shared tensor is no longer needed. + Failing to call this method may result in shared memory leaks. + """ + # Close first to set _closed flag and release cache + self.close() + + # Then unlink + try: + self._shm.unlink() + except Exception as e: + raise RuntimeError( + f"Error unlinking shared memory {self._shm_name}: {e}" + ) from e + + @property + def is_closed(self) -> bool: + """ + Check if this SharedTensor has been closed. + + Returns: + bool: True if close() has been called, False otherwise + """ + return self._closed + + def __enter__(self): + """Context manager entry.""" + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Context manager exit - closes the shared memory handle.""" + self.close() + return False + + def __del__(self): + """ + Best-effort cleanup on garbage collection. + + WARNING: Do NOT rely on __del__ for cleanup! Python's garbage collector + may not call __del__ promptly or at all, which can cause memory leaks. + Always explicitly call close() when done with the SharedTensor. + + This __del__ is only a safety net for cases where explicit cleanup is missed. + """ + # Only close if the object was fully initialized + if hasattr(self, "_closed"): + self.close() + + def __repr__(self): + return f"SharedTensor(shape={self._shape}, dtype={self._dtype}, shm_name={self._shm_name})" diff --git a/tests/unit_tests/util/test_shared_tensor.py b/tests/unit_tests/util/test_shared_tensor.py new file mode 100644 index 000000000..f922c3733 --- /dev/null +++ b/tests/unit_tests/util/test_shared_tensor.py @@ -0,0 +1,905 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pickle +import time + +import pytest +import torch + +# Assuming SharedTensor is in shared_tensor.py +from forge.util._shared_tensor import SharedTensor +from multiprocess import Process, Queue + + +class TestSharedTensorCreation: + """Test tensor creation methods""" + + def test_empty_creation(self): + """Test creating empty tensor""" + shape = (100, 200) + dtype = torch.float32 + + shared = SharedTensor.empty(shape, dtype) + + assert shared.tensor.shape == torch.Size(shape) + assert shared.tensor.dtype == dtype + assert shared.tensor.shape == torch.Size(shape) + assert shared.tensor.dtype == dtype + + shared.drop() + + def test_empty_with_bfloat16(self): + """Test creating empty bfloat16 tensor""" + shape = (50, 50) + shared = SharedTensor.empty(shape, torch.bfloat16) + + assert shared.tensor.dtype == torch.bfloat16 + assert shared.tensor.dtype == torch.bfloat16 + + shared.drop() + + def test_zeros_creation(self): + """Test creating zero-initialized tensor""" + shape = (10, 20) + shared = SharedTensor.zeros(shape, torch.float32) + + tensor = shared.tensor + assert torch.all(tensor == 0) + assert tensor.sum().item() == 0.0 + + shared.drop() + + def test_ones_creation(self): + """Test creating ones-initialized tensor""" + shape = (10, 20) + shared = SharedTensor.ones(shape, torch.float32) + + tensor = shared.tensor + assert torch.all(tensor == 1) + assert tensor.sum().item() == 200.0 + + shared.drop() + + def test_from_tensor_creation(self): + """Test creating from existing tensor""" + original = torch.randn(50, 50) + shared = SharedTensor(tensor=original) + + assert shared.tensor.shape == original.shape + assert shared.tensor.dtype == original.dtype + assert torch.allclose(shared.tensor, original) + + shared.drop() + + def test_from_handle_creation(self): + """Test creating from handle""" + # Create original + original = SharedTensor.empty((10, 10), torch.float32) + original.tensor.fill_(5.0) + + # Get handle + handle = original.get_handle() + + # Create from handle + reconstructed = SharedTensor(handle=handle) + + assert torch.all(reconstructed.tensor == 5.0) + assert reconstructed.tensor.shape == original.tensor.shape + assert reconstructed.tensor.dtype == original.tensor.dtype + + original.drop() + + def test_creation_requires_argument(self): + """Test that creation without arguments raises error""" + with pytest.raises(ValueError, match="Must provide either tensor or handle"): + SharedTensor() + + @pytest.mark.parametrize( + "shape", + [ + (10,), + (10, 20), + (5, 10, 15), + (2, 3, 4, 5), + ], + ) + def test_various_shapes(self, shape): + """Test creation with various shapes""" + shared = SharedTensor.empty(shape, torch.float32) + assert shared.tensor.shape == torch.Size(shape) + assert shared.tensor.shape == torch.Size(shape) + shared.drop() + + +class TestSharedTensorDtypes: + """Test all supported dtypes""" + + @pytest.mark.parametrize( + "dtype", + [ + torch.float32, + torch.float64, + torch.float16, + torch.bfloat16, + torch.int32, + torch.int64, + torch.int16, + torch.int8, + torch.uint8, + torch.bool, + ], + ) + def test_all_dtypes(self, dtype): + """Test that all dtypes work correctly""" + shape = (10, 10) + shared = SharedTensor.empty(shape, dtype) + + assert shared.tensor.dtype == dtype + assert shared.tensor.dtype == dtype + + # Test that we can write to it + if dtype == torch.bool: + shared.tensor.fill_(True) + elif dtype in [torch.int32, torch.int64, torch.int16, torch.int8, torch.uint8]: + shared.tensor.fill_(42) + else: + shared.tensor.fill_(3.14) + + shared.drop() + + def test_dtype_conversion_in_handle(self): + """Test dtype is preserved through handle""" + for dtype in [torch.float32, torch.bfloat16, torch.int64]: + shared1 = SharedTensor.empty((5, 5), dtype) + handle = shared1.get_handle() + + shared2 = SharedTensor(handle=handle) + assert shared2.tensor.dtype == dtype + + shared1.drop() + + +class TestSharedTensorOperations: + """Test tensor operations""" + + def test_copy_from(self): + """Test copying data from another tensor""" + source = torch.randn(20, 30) + shared = SharedTensor.empty((20, 30), torch.float32) + + shared.copy_from(source) + + assert torch.allclose(shared.tensor, source) + shared.drop() + + def test_copy_from_shape_mismatch(self): + """Test copy_from raises error on shape mismatch""" + source = torch.randn(10, 10) + shared = SharedTensor.empty((20, 20), torch.float32) + + with pytest.raises(ValueError, match="Shape mismatch"): + shared.copy_from(source) + + shared.drop() + + def test_clone(self): + """Test cloning creates independent copy""" + original = SharedTensor.empty((10, 10), torch.float32) + original.tensor.fill_(5.0) + + cloned = original.clone() + + # Verify data is same + assert torch.all(cloned.tensor == 5.0) + + # Verify they're independent + original.tensor.fill_(10.0) + assert torch.all(cloned.tensor == 5.0) + assert torch.all(original.tensor == 10.0) + + original.drop() + cloned.drop() + + def test_tensor_modifications(self): + """Test that modifications to tensor are reflected""" + shared = SharedTensor.zeros((10, 10), torch.float32) + tensor = shared.tensor + + tensor[0, 0] = 99.0 + tensor[5:, :] = 42.0 + + # Get tensor again and verify changes persist + tensor2 = shared.tensor + assert tensor2[0, 0].item() == 99.0 + assert torch.all(tensor2[5:, :] == 42.0) + + shared.drop() + + def test_inplace_operations(self): + """Test in-place operations work""" + shared = SharedTensor.empty((100, 100), torch.float32) + tensor = shared.tensor + + tensor.normal_(0, 1) + mean = tensor.mean().item() + + tensor.add_(5.0) + new_mean = tensor.mean().item() + + assert abs(new_mean - (mean + 5.0)) < 0.1 + + shared.drop() + + +class TestSharedTensorSerialization: + """Test pickling and handle serialization""" + + def test_handle_is_picklable(self): + """Test that handle can be pickled""" + shared = SharedTensor.empty((10, 10), torch.float32) + handle = shared.get_handle() + + # Pickle and unpickle + pickled = pickle.dumps(handle) + unpickled_handle = pickle.loads(pickled) + + assert unpickled_handle == handle + + shared.drop() + + def test_handle_small_size(self): + """Test that handle is small (efficient for RPC)""" + shared = SharedTensor.empty((10000, 10000), torch.float32) + handle = shared.get_handle() + + pickled = pickle.dumps(handle) + + # Handle should be < 1KB even for huge tensors + assert len(pickled) < 1024 + + shared.drop() + + def test_data_integrity_after_pickle(self): + """Test data is preserved through handle pickling""" + # Create and fill tensor + shared1 = SharedTensor.empty((50, 50), torch.bfloat16) + shared1.tensor.normal_(0, 1) + original_data = shared1.tensor.clone() + + # Pickle handle + handle = shared1.get_handle() + pickled = pickle.dumps(handle) + unpickled_handle = pickle.loads(pickled) + + # Reconstruct + shared2 = SharedTensor(handle=unpickled_handle) + + # Verify data is same + assert torch.allclose(shared2.tensor.float(), original_data.float(), rtol=1e-3) + + shared1.drop() + + +class TestSharedTensorMemory: + """Test memory management and cleanup""" + + def test_drop(self): + """Test drop removes shared memory""" + shared = SharedTensor.empty((10, 10), torch.float32) + shm_name = shared._shm_name + + # Verify shared memory exists + tensor = shared.tensor + tensor.fill_(5.0) + + # Drop shared memory + shared.drop() + + # Trying to attach should fail + from multiprocessing import shared_memory + + with pytest.raises(FileNotFoundError): + shared_memory.SharedMemory(name=shm_name) + + def test_multiple_views_same_memory(self): + """Test multiple tensor views point to same memory""" + shared = SharedTensor.empty((10, 10), torch.float32) + + tensor1 = shared.tensor + tensor1.fill_(5.0) + + tensor2 = shared.tensor + assert torch.all(tensor2 == 5.0) + + # Modify through tensor2 + tensor2.fill_(10.0) + + # Verify tensor1 sees the change + assert torch.all(tensor1 == 10.0) + + shared.drop() + + def test_handle_reconstruction_shares_memory(self): + """Test that handle reconstruction shares same memory""" + shared1 = SharedTensor.empty((20, 20), torch.float32) + shared1.tensor.fill_(7.0) + + handle = shared1.get_handle() + shared2 = SharedTensor(handle=handle) + + # Modify through shared2 + shared2.tensor.fill_(14.0) + + # Verify shared1 sees the change + assert torch.all(shared1.tensor == 14.0) + + shared1.drop() + + +class TestSharedTensorEdgeCases: + """Test edge cases and error conditions""" + + def test_empty_shape(self): + """Test scalar tensor (empty shape)""" + shared = SharedTensor.ones((), torch.float32) + assert shared.tensor.shape == () + assert shared.tensor.numel() == 1 + assert torch.allclose( + shared.tensor, + torch.ones( + (), + ), + ) + shared.drop() + + def test_single_element_tensor(self): + """Test 1-element tensor""" + shared = SharedTensor.empty((1,), torch.float32) + shared.tensor.fill_(42.0) + assert shared.tensor.item() == 42.0 + shared.drop() + + def test_large_tensor(self): + """Test large tensor (1GB)""" + # 1GB tensor: 250M float32 elements + shape = (250_000_000,) + shared = SharedTensor.empty(shape, torch.float32) + + assert shared.tensor.shape == shape + assert shared.tensor.numel() == 250_000_000 + + shared.drop() + + def test_non_contiguous_tensor_conversion(self): + """Test that non-contiguous tensors are handled""" + # Create non-contiguous tensor + original = torch.randn(10, 10).t() # Transpose makes it non-contiguous + assert not original.is_contiguous() + + # Should work (internally makes contiguous) + shared = SharedTensor(tensor=original) + + # Result should match + assert torch.allclose(shared.tensor, original) + + shared.drop() + + def test_repr(self): + """Test string representation""" + shared = SharedTensor.empty((10, 20), torch.float32) + repr_str = repr(shared) + + assert "SharedTensor" in repr_str + assert "10, 20" in repr_str + assert "float32" in repr_str + assert shared._shm_name in repr_str + + shared.drop() + + +class TestSharedTensorMultiprocess: + """Test multiprocess scenarios""" + + def test_multiprocess_read(self): + """Test reading shared tensor from another process""" + + def reader_process(handle_dict, result_queue): + with SharedTensor(handle=handle_dict) as shared: + result_queue.put(shared.tensor.sum().item()) + + # Create shared tensor in main process + shared = SharedTensor.empty((100, 100), torch.float32) + shared.tensor.fill_(5.0) + + # Read from child process + result_queue = Queue() + handle = shared.get_handle() + + p = Process(target=reader_process, args=(handle, result_queue)) + p.start() + p.join() + + result = result_queue.get() + expected = 5.0 * 100 * 100 + + assert abs(result - expected) < 1e-5 + + shared.drop() + + def test_multiprocess_write(self): + """Test writing to shared tensor from another process""" + + def writer_process(handle_dict, value): + with SharedTensor(handle=handle_dict) as shared: + shared.tensor.fill_(value) + + # Create empty shared tensor + shared = SharedTensor.empty((50, 50), torch.float32) + shared.tensor.zero_() + + # Write from child process + handle = shared.get_handle() + + p = Process(target=writer_process, args=(handle, 42.0)) + p.start() + p.join() + + # Verify in main process + assert torch.all(shared.tensor == 42.0) + + shared.drop() + + def test_multiprocess_bidirectional(self): + """Test bidirectional communication""" + + def worker_process(input_handle, output_handle): + with SharedTensor(handle=input_handle) as input_shared: + with SharedTensor(handle=output_handle) as output_shared: + # Compute: output = input * 2 + output_shared.tensor.copy_(input_shared.tensor * 2) + + # Create input and output tensors + input_shared = SharedTensor.empty((100, 100), torch.float32) + input_shared.tensor.normal_(0, 1) + input_data = input_shared.tensor.clone() + + output_shared = SharedTensor.empty((100, 100), torch.float32) + + # Process in child + p = Process( + target=worker_process, + args=(input_shared.get_handle(), output_shared.get_handle()), + ) + p.start() + p.join() + + # Verify result + expected = input_data * 2 + assert torch.allclose( + output_shared.tensor, expected + ), "output: {}, expected: {}".format(output_shared.tensor, expected) + + input_shared.drop() + output_shared.drop() + + +class TestSharedTensorPerformance: + """Performance-related tests""" + + def test_empty_faster_than_from_tensor(self): + """Test that empty() is faster than from tensor""" + shape = (1000, 1000) + + # Time empty creation + start = time.time() + for _ in range(10): + shared = SharedTensor.empty(shape, torch.float32) + shared.drop() + empty_time = time.time() - start + + # Time from_tensor creation + start = time.time() + for _ in range(10): + tensor = torch.randn(shape) + shared = SharedTensor(tensor=tensor) + shared.drop() + from_tensor_time = time.time() - start + + # empty() should be faster (no data copying) + assert empty_time < from_tensor_time + + def test_handle_serialization_fast(self): + """Test that handle serialization is fast""" + shared = SharedTensor.empty((10000, 10000), torch.float32) + handle = shared.get_handle() + + start = time.time() + for _ in range(1000): + pickled = pickle.dumps(handle) + unpickled = pickle.loads(pickled) + elapsed = time.time() - start + + # Should be able to do 1000 round trips in < 0.1 seconds + assert elapsed < 0.1 + + shared.drop() + + +class TestSharedTensorHandleToSharedTensor: + """Test SharedTensorHandle.to_shared_tensor() method""" + + def test_to_shared_tensor_basic(self): + """Test basic creation of SharedTensor from handle using to_shared_tensor method""" + original = SharedTensor.empty((10, 10), torch.float32) + original.tensor.fill_(7.0) + + handle = original.get_handle() + reconstructed = handle.to_shared_tensor() + + assert torch.all(reconstructed.tensor == 7.0) + assert reconstructed.tensor.shape == original.tensor.shape + assert reconstructed.tensor.dtype == original.tensor.dtype + + original.drop() + + def test_to_shared_tensor_preserves_data(self): + """Test that to_shared_tensor preserves original data""" + original = SharedTensor.empty((20, 30), torch.float32) + original.tensor.normal_(0, 1) + original_data = original.tensor.clone() + + handle = original.get_handle() + reconstructed = handle.to_shared_tensor() + + assert torch.allclose(reconstructed.tensor, original_data) + + original.drop() + + def test_to_shared_tensor_shares_memory(self): + """Test that to_shared_tensor shares memory with original""" + original = SharedTensor.empty((15, 15), torch.float32) + original.tensor.fill_(5.0) + + handle = original.get_handle() + reconstructed = handle.to_shared_tensor() + + reconstructed.tensor.fill_(10.0) + + assert torch.all(original.tensor == 10.0) + + original.drop() + + def test_to_shared_tensor_with_various_dtypes(self): + """Test to_shared_tensor works with different data types""" + for dtype in [torch.float32, torch.float64, torch.bfloat16, torch.int32]: + original = SharedTensor.empty((5, 5), dtype) + if ( + dtype == torch.bfloat16 + or dtype == torch.float32 + or dtype == torch.float64 + ): + original.tensor.normal_(0, 1) + else: + original.tensor.fill_(42) + + handle = original.get_handle() + reconstructed = handle.to_shared_tensor() + + assert reconstructed.tensor.dtype == dtype + if dtype == torch.bfloat16: + assert torch.allclose( + reconstructed.tensor.float(), original.tensor.float(), rtol=1e-3 + ) + else: + assert torch.allclose(reconstructed.tensor, original.tensor) + + original.drop() + + def test_to_shared_tensor_multiprocess(self): + """Test to_shared_tensor in multiprocess scenario""" + + def worker_process(handle, result_queue): + with handle.to_shared_tensor() as shared: + result_queue.put(shared.tensor.sum().item()) + + original = SharedTensor.empty((50, 50), torch.float32) + original.tensor.fill_(3.0) + + handle = original.get_handle() + result_queue = Queue() + + p = Process(target=worker_process, args=(handle, result_queue)) + p.start() + p.join() + + result = result_queue.get() + expected = 3.0 * 50 * 50 + + assert abs(result - expected) < 1e-5 + + original.drop() + + def test_to_shared_tensor_equivalent_to_constructor(self): + """Test that handle.to_shared_tensor() is equivalent to SharedTensor(handle=handle)""" + original = SharedTensor.empty((25, 25), torch.float32) + original.tensor.normal_(0, 1) + + handle = original.get_handle() + + via_method = handle.to_shared_tensor() + via_constructor = SharedTensor(handle=handle) + + assert torch.allclose(via_method.tensor, via_constructor.tensor) + assert via_method.tensor.shape == via_constructor.tensor.shape + assert via_method.tensor.dtype == via_constructor.tensor.dtype + + original.drop() + + +class TestSharedTensorBfloat16: + """Specific tests for bfloat16 support""" + + def test_bfloat16_creation(self): + """Test bfloat16 tensor creation""" + shared = SharedTensor.empty((100, 100), torch.bfloat16) + assert shared.tensor.dtype == torch.bfloat16 + shared.drop() + + def test_bfloat16_from_tensor(self): + """Test creating shared tensor from bfloat16 tensor""" + original = torch.randn(50, 50, dtype=torch.bfloat16) + shared = SharedTensor(tensor=original) + + assert shared.tensor.dtype == torch.bfloat16 + assert torch.allclose(shared.tensor.float(), original.float(), rtol=1e-3) + + shared.drop() + + def test_bfloat16_handle_preservation(self): + """Test bfloat16 dtype preserved through handle""" + shared1 = SharedTensor.empty((20, 20), torch.bfloat16) + shared1.tensor.normal_(0, 1) + + handle = shared1.get_handle() + shared2 = SharedTensor(handle=handle) + + assert shared2.tensor.dtype == torch.bfloat16 + assert torch.allclose(shared1.tensor.float(), shared2.tensor.float(), rtol=1e-3) + + shared1.drop() + + def test_bfloat16_operations(self): + """Test operations on bfloat16 tensors""" + shared = SharedTensor.empty((100, 100), torch.bfloat16) + tensor = shared.tensor + + tensor.normal_(0, 1) + mean = tensor.float().mean().item() + + # Mean should be close to 0 + assert abs(mean) < 0.1 + + shared.drop() + + +class TestSharedTensorCloseAndCleanup: + """Test explicit close() and cleanup patterns to prevent memory leaks""" + + def test_close_method(self): + """Test explicit close() releases handle and sets closed state""" + shared = SharedTensor.empty((10, 10), torch.float32) + shared.tensor.fill_(5.0) + + assert not shared.is_closed + + # Close should not raise + shared.close() + + assert shared.is_closed + + # Cleanup + shared._shm.unlink() + + def test_tensor_access_after_close_raises_error(self): + """Test that accessing tensor after close raises RuntimeError""" + shared = SharedTensor.empty((10, 10), torch.float32) + shared.tensor.fill_(5.0) + + shared.close() + + with pytest.raises(RuntimeError, match="Cannot access tensor after close"): + _ = shared.tensor + + # Cleanup + shared._shm.unlink() + + def test_get_handle_after_close_raises_error(self): + """Test that getting handle after close raises RuntimeError""" + shared = SharedTensor.empty((10, 10), torch.float32) + + shared.close() + + with pytest.raises(RuntimeError, match="Cannot get handle after close"): + shared.get_handle() + + # Cleanup + shared._shm.unlink() + + def test_is_closed_property(self): + """Test is_closed property reflects state correctly""" + shared = SharedTensor.empty((10, 10), torch.float32) + + assert not shared.is_closed + + shared.close() + + assert shared.is_closed + + # Cleanup + shared._shm.unlink() + + def test_cached_tensor_reference_becomes_invalid_after_close(self): + """Test that tensor reference obtained before close becomes invalid after close""" + shared = SharedTensor.empty((10, 10), torch.float32) + shared.tensor.fill_(5.0) + + # Get reference before close + tensor_ref = shared.tensor + + shared.close() + + # After close(), the memory mapping is unmapped, so even cached references + # point to invalid memory. Accessing them will cause segfault or undefined behavior. + # We can't safely test this, but we document it. + + # Accessing via shared.tensor raises error (this is what we CAN test) + with pytest.raises(RuntimeError): + _ = shared.tensor + + # Cleanup + shared._shm.unlink() + + def test_context_manager(self): + """Test context manager automatically closes""" + shm_name = None + + with SharedTensor.empty((10, 10), torch.float32) as shared: + shm_name = shared._shm_name + shared.tensor.fill_(7.0) + assert torch.all(shared.tensor == 7.0) + + # After exiting context, should be closed (but not unlinked yet) + # We need to unlink separately + from multiprocessing import shared_memory + + # Should still be able to attach (not unlinked) + shm = shared_memory.SharedMemory(name=shm_name) + shm.close() + shm.unlink() + + def test_creator_receiver_workflow(self): + """Test proper workflow: creator creates, gets handle, closes, receiver uses and closes""" + + def receiver_process(handle, result_queue): + # Receiver creates SharedTensor from handle + with SharedTensor(handle=handle) as shared: + result = shared.tensor.sum().item() + result_queue.put(result) + # Context manager auto-closes + + # Creator process + shared = SharedTensor.empty((50, 50), torch.float32) + shared.tensor.fill_(4.0) + handle = shared.get_handle() + shared.close() # Creator closes its reference + + # Pass to receiver + result_queue = Queue() + p = Process(target=receiver_process, args=(handle, result_queue)) + p.start() + p.join() + + result = result_queue.get() + assert abs(result - (4.0 * 50 * 50)) < 1e-5 + + # Unlink after all processes done + handle.drop() + + def test_handle_drop_without_creating_shared_tensor(self): + """Test that handle.drop() doesn't create unnecessary SharedTensor instance""" + shared = SharedTensor.empty((10, 10), torch.float32) + shared.tensor.fill_(3.0) + handle = shared.get_handle() + shared.close() + + # drop() should work without creating new SharedTensor + handle.drop() + + # Memory should be unlinked + from multiprocessing import shared_memory + + with pytest.raises(FileNotFoundError): + shared_memory.SharedMemory(name=handle.shm_name) + + def test_multiple_receivers_close_independently(self): + """Test that multiple receivers can close independently""" + + def receiver_process(handle, value, result_queue): + with SharedTensor(handle=handle) as shared: + result = shared.tensor[0, 0].item() == value + result_queue.put(result) + + # Creator + shared = SharedTensor.empty((10, 10), torch.float32) + shared.tensor.fill_(9.0) + handle = shared.get_handle() + shared.close() + + # Multiple receivers + result_queue = Queue() + processes = [] + for _ in range(3): + p = Process(target=receiver_process, args=(handle, 9.0, result_queue)) + p.start() + processes.append(p) + + for p in processes: + p.join() + + # All should succeed + for _ in range(3): + assert result_queue.get() is True + + # Cleanup + handle.drop() + + def test_close_is_idempotent(self): + """Test that calling close() multiple times is safe""" + shared = SharedTensor.empty((10, 10), torch.float32) + + # Multiple closes should not raise + shared.close() + shared.close() + shared.close() + + # Cleanup + shared.drop() + + def test_drop_is_idempotent(self): + """Test that calling drop() multiple times is safe""" + shared = SharedTensor.empty((10, 10), torch.float32) + handle = shared.get_handle() + shared.close() + + # Multiple drops should not raise + handle.drop() + handle.drop() + handle.drop() + + def test_proper_cleanup_prevents_leak(self): + """Test that proper close + unlink pattern doesn't leak""" + import glob + + # Get initial shared memory count + shm_before = len(glob.glob("/dev/shm/shared_tensor_*")) + + # Create and properly cleanup 10 shared tensors + for _ in range(10): + shared = SharedTensor.empty((100, 100), torch.float32) + handle = shared.get_handle() + shared.close() + handle.drop() + + # Check no leaks + shm_after = len(glob.glob("/dev/shm/shared_tensor_*")) + assert ( + shm_after == shm_before + ), f"Memory leak detected: {shm_after - shm_before} tensors leaked" + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "--tb=short"])