diff --git a/src/forge/actors/generator.py b/src/forge/actors/generator.py index 8f8cf8fc7..adbc6f3dc 100644 --- a/src/forge/actors/generator.py +++ b/src/forge/actors/generator.py @@ -9,10 +9,12 @@ import asyncio import logging import os +import queue import sys from collections.abc import Mapping from copy import copy from dataclasses import dataclass, field +from typing import Optional import torch import torchstore as ts @@ -49,7 +51,12 @@ load_tensor_from_dcp, ) -from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh +from forge.controller import ( + ForgeActor, + get_proc_mesh, + host_mesh_from_proc, + stop_proc_mesh, +) from forge.data_models.completion import Completion from forge.data_models.prompt import to_prompt from forge.env import TORCHSTORE_USE_RDMA @@ -57,6 +64,7 @@ from forge.observability.metrics import record_metric, Reduce from forge.observability.perf_tracker import Tracer from forge.types import ProcessConfig +from forge.util._shared_tensor import SharedTensor, SharedTensorHandle logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) @@ -143,17 +151,20 @@ async def launch( # pyright: ignore[reportIncompatibleMethodOverride] ) worker_procs = await get_proc_mesh(process_config=process_config) - # TODO - issues/144 we will want to ensure colocation with workers - # We're currently locating the Generator on the local host proc mesh - # vLLM initialization without setting env variables at proc_mesh creation - # level leads to issues. - # Once we can create multiple proc meshes on a host mesh, we can ensure - # host colocation + # Grab a single host from the workers... + host_mesh = await host_mesh_from_proc(worker_procs) + singleton_slice = {k: slice(0, 1) for k in host_mesh.extent.keys()} + host_mesh = host_mesh.slice(**singleton_slice) generator_proc_config = copy(process_config) generator_proc_config.procs = 1 - generator_proc_config.hosts = None generator_proc_config.with_gpus = False - generator_proc = await get_proc_mesh(process_config=generator_proc_config) + + # By passing in the host_mesh here, we will get a new proc + # spawned on the provided host_mesh. Since that host mesh is + # taken from the policy_proc, this ensures colocation. + generator_proc = await get_proc_mesh( + process_config=generator_proc_config, host_mesh=host_mesh + ) if isinstance(engine_args, Mapping): engine_args = EngineArgs(**engine_args) @@ -204,6 +215,11 @@ async def setup(self): self.request_lock = asyncio.Condition() # Guard for accepting_requests self.update_lock = asyncio.Condition() # Guard for updating requests + # Shared memory allocated for weight updates + self.cached_state_dict_allocs: queue.Queue[ + dict[str, SharedTensorHandle] + ] = queue.Queue(maxsize=2) + vllm_config: VllmConfig = self.engine_args.create_engine_config( UsageContext.LLM_CLASS ) @@ -244,6 +260,59 @@ def _start_processing(self): if self._run_task is None or self._run_task.done(): self._run_task = asyncio.create_task(self.run()) + async def _drop_shared_memory(self, state_dict: dict[str, SharedTensorHandle]): + for handle in state_dict.values(): + handle.drop() + + async def _cleanup_shared_memory(self): + """Cleanup shared memory allocated for weight updates.""" + while not self.cached_state_dict_allocs.empty(): + try: + state_dict = self.cached_state_dict_allocs.get_nowait() + await self._drop_shared_memory(state_dict) + except queue.Empty: + logger.info( + "Cached state dict alloc queue is empty. No state dict to drop." 
+ ) + + async def _fetch_weights( + self, + version: int, + *, + pre_allocated: Optional[dict[str, SharedTensorHandle]] = None, + ) -> dict[str, SharedTensorHandle]: + """Fetch weights from torchstore and return a dict of {name: SharedTensorHandle}.""" + t = Tracer("generator_perf/_fetch_weights") + t.start() + prefix = get_param_prefix(version) + matching_keys = await ts.keys(prefix) + hf_param_names = [extract_param_name(key) for key in matching_keys] + # We can't pass a generator since vllm load_weights is not async. + # Instead, we just call load_weights with one parameter at a time. + shared_memory_state_dict = {} + if pre_allocated is not None: + logger.info( + "[Generator] fetching weights from torchstore to shared memory. Using pre allocated shared memory." + ) + shared_memory_state_dict = pre_allocated + assert set(shared_memory_state_dict.keys()) == set( + hf_param_names + ), "The pre_allocated dict must have the same keys as hf_param_names" + for name, handle in shared_memory_state_dict.items(): + param_key = get_param_key(version, name) + param = handle.to_shared_tensor().tensor + await ts.get(param_key, inplace_tensor=param) + else: + logger.info( + "[Generator] fetching weights from torchstore to shared memory." + ) + for name in hf_param_names: + param_key = get_param_key(version, name) + param = await ts.get(param_key) + shared_memory_state_dict[name] = SharedTensor(tensor=param).get_handle() + t.stop() + return shared_memory_state_dict + @endpoint async def generate(self, prompt: str, *, priority: int = 0) -> list[Completion]: """Generate a response for the given prompt @@ -400,6 +469,14 @@ async def update_weights(self, version: int) -> None: >>> await trainer.push_weights() >>> generator.update_weights(version) """ + logger.info(f"[Generator] Fetching weights for v{version} to shared memory") + try: + pre_allocated = self.cached_state_dict_allocs.get_nowait() + except queue.Empty: + pre_allocated = None + fetch_task = asyncio.create_task( + self._fetch_weights(version, pre_allocated=pre_allocated) + ) # Serialize updates (only one update at a time) async with self.update_lock: # Grab the lock to stop accepting requests and wait on pending requests @@ -431,8 +508,27 @@ async def update_weights(self, version: int) -> None: ) logger.debug(f"Starting weight update on {self.__class__.__name__}") - # Call update_weights on every generator_worker - await self.generator_worker.update_weights.call(version=version) + if not self.use_dcp: + # TODO: currently the alloc in ts.get will block the event loop unfortunately + # potentially we need to change torchstore + # We have to do this because Monarch future is not directly compatible with asyncio + t = Tracer("generator_perf/waiting_for_fetch_weights") + t.start() + fetched_weights = await fetch_task + t.stop() + # Call update_weights on every policy_worker + await self.generator_worker.update_weights.call( + shared_memory_state_dict=fetched_weights + ) + try: + self.cached_state_dict_allocs.put_nowait(fetched_weights) + except queue.Full: + logger.info( + "Cached state dict alloc queue is full. Dropping allocated state dict." + ) + await self._drop_shared_memory(fetched_weights) + else: + await self.generator_worker.update_weights.call(version=version) self.generator_version = version # After updating the weights, we need to reset the KV cache @@ -504,6 +600,7 @@ async def shutdown( # pyright: ignore[reportIncompatibleMethodOverride] # TODO - may want to expand stop to gracefully respond to # ongoing requests. 
         await actor.stop.call()
+        await actor._cleanup_shared_memory.call()
 
         await stop_proc_mesh(actor._worker_procs)
         await stop_proc_mesh(actor._generator_proc)
@@ -597,13 +694,37 @@ async def execute_model(self, schedule: SchedulerOutput) -> ModelRunnerOutput:
         return self.worker.execute_model(schedule)
 
     @endpoint
-    async def update_weights(self, version: int) -> None:
+    async def update_weights(
+        self,
+        version: Optional[int] = None,
+        *,
+        shared_memory_state_dict: Optional[dict[str, SharedTensorHandle]] = None,
+    ) -> None:
         model = self.worker.model_runner.model
+        if shared_memory_state_dict is not None:
+            logger.info("[PolicyWorker] update weights from shared memory.")
+            t = Tracer(
+                "generator_worker_perf/update_weights_from_shared_memory", timer="gpu"
+            )
+            t.start()
+            loaded_weights = set()
+            for name, param_handle in shared_memory_state_dict.items():
+                param = param_handle.to_shared_tensor().tensor
+                loaded = model.load_weights([(name, param)])
+                loaded_weights.update(loaded)
+            logger.info(f"[PolicyWorker] updated {len(loaded_weights)} parameters")
+            t.stop()
+            return
+        if version is None:
+            raise ValueError(
+                "version must be provided if not using shared_memory_state_dict"
+            )
+        # If shared memory is not provided, we assume we are using DCP
         prefix = get_param_prefix(version)
         matching_keys = await ts.keys(prefix)
         dcp_whole_state_dict_key = get_dcp_whole_state_dict_key(version)
         loaded_weights = set()
-        t = Tracer("worker_perf/update_weights", timer="gpu")
+        t = Tracer("generator_worker_perf/update_weights", timer="gpu")
         t.start()
         # Entire state dict is stored in a single DCP handle
         if dcp_whole_state_dict_key in matching_keys:
@@ -614,16 +735,8 @@ async def update_weights(self, version: int) -> None:
                 loaded = model.load_weights([(name, param)])
                 del param
                 loaded_weights.update(loaded)
-        else:  # Load each parameter from torchstore directly without DCP
-            hf_param_names = [extract_param_name(key) for key in matching_keys]
-            # We can't pass a generator since vllm load_weights is not async.
-            # Instead, we just call load_weights with one parameter at a time.
-            for name in hf_param_names:
-                param_key = get_param_key(version, name)
-                param = await ts.get(param_key)
-                loaded = model.load_weights([(name, param)])
-                del param
-                loaded_weights.update(loaded)
+        else:
+            raise RuntimeError("No DCP handle found for the given version")
         t.stop()
 
     @endpoint
diff --git a/src/forge/actors/trainer.py b/src/forge/actors/trainer.py
index dd85b3c82..23ab870eb 100644
--- a/src/forge/actors/trainer.py
+++ b/src/forge/actors/trainer.py
@@ -5,7 +5,6 @@
 # LICENSE file in the root directory of this source tree.
 
 import logging
-import math
 import os
 import shutil
 
@@ -18,7 +17,7 @@
 import torch.distributed.checkpoint as dcp
 import torchstore as ts
 
-from monarch.actor import current_rank, current_size, endpoint
+from monarch.actor import endpoint
 from torch import Tensor
 from torch.distributed.checkpoint._nested_dict import flatten_state_dict
 from torchtitan.config.job_config import (
@@ -163,19 +162,8 @@ def __post_init__(self):
         self.step = 1  # fragile contract.
         self.num_training_steps = self.training.steps
         self.gradient_accumulation_steps = 1
-        self.rank = current_rank().rank
-        self.size = math.prod(current_size().values())
         env = {
-            "RANK": str(self.rank),
-            "LOCAL_RANK": str(self.rank),
-            "LOCAL_WORLD_SIZE": str(self.size),
-            "GROUP_RANK": str(self.size),
-            "GROUP_WORLD_SIZE": str(self.size),
-            "ROLE_RANK": str(self.rank),
-            "ROLE_WORLD_SIZE": str(self.size),
-            "ROLE_NAME": "rank",
-            "WORLD_SIZE": str(self.size),
             "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
         }
         os.environ.update(env)
 
diff --git a/src/forge/controller/provisioner.py b/src/forge/controller/provisioner.py
index 6d470a87f..6ebdb2117 100644
--- a/src/forge/controller/provisioner.py
+++ b/src/forge/controller/provisioner.py
@@ -20,8 +20,9 @@
 
 from monarch.tools import commands
 
-from forge.controller.launcher import BaseLauncher, get_launcher
+from monarch.utils import setup_env_for_distributed
 
+from forge.controller.launcher import BaseLauncher, get_launcher
 from forge.env import all_env_vars, FORGE_DISABLE_METRICS
 from forge.types import ProcessConfig, ProvisionerConfig
 
@@ -283,6 +284,13 @@ def bootstrap(env: dict[str, str]):
             bootstrap=functools.partial(bootstrap, env=env_vars),
         )
 
+        # Set up environment variables for PyTorch distributed...
+        await setup_env_for_distributed(
+            procs,
+            master_addr=addr,
+            master_port=port,
+        )
+
         if is_remote:
             await self.launcher.remote_setup(procs)
 
diff --git a/src/forge/util/_shared_tensor.py b/src/forge/util/_shared_tensor.py
new file mode 100644
index 000000000..7839fe2d0
--- /dev/null
+++ b/src/forge/util/_shared_tensor.py
@@ -0,0 +1,244 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from __future__ import annotations
+
+import functools
+import uuid
+from dataclasses import dataclass
+from multiprocessing import shared_memory
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+
+@dataclass
+class SharedTensorHandle:
+    shm_name: str
+    shape: Tuple[int, ...]
+    dtype: str
+
+    def to_shared_tensor(self) -> SharedTensor:
+        """
+        Create a SharedTensor from this handle.
+
+        Returns:
+            SharedTensor instance attached to the shared memory referenced by this handle
+        """
+        return SharedTensor(handle=self)
+
+    def drop(self) -> None:
+        """Release the shared memory segment referenced by this handle.
+
+        Generator._drop_shared_memory() calls drop() directly on handles, so
+        the handle must be able to close and unlink its own segment.
+        """
+        shm = shared_memory.SharedMemory(name=self.shm_name)
+        shm.close()
+        shm.unlink()
+
+
+class SharedTensor:
+    """Wrapper class for tensors backed by shared memory"""
+
+    def __init__(
+        self,
+        *,
+        tensor: Optional[torch.Tensor] = None,
+        handle: Optional[SharedTensorHandle] = None,
+    ):
+        if tensor is not None:
+            self._create_from_tensor(tensor)
+        elif handle is not None:
+            self._create_from_handle(handle)
+        else:
+            raise ValueError("Must provide either tensor or handle")
+
+    @classmethod
+    def empty(
+        cls,
+        shape: Union[Tuple[int, ...], torch.Size],
+        dtype: torch.dtype = torch.float32,
+    ):
+        """
+        Create an empty tensor directly in shared memory (no copy/allocation overhead)
+
+        Args:
+            shape: Shape of the tensor
+            dtype: PyTorch dtype (supports bfloat16, float32, etc.)
+ + Returns: + SharedTensor instance with uninitialized data + """ + instance = cls.__new__(cls) + instance._create_empty(shape, dtype) + return instance + + @classmethod + def zeros( + cls, + shape: Union[Tuple[int, ...], torch.Size], + dtype: torch.dtype = torch.float32, + ): + """ + Create a zero-initialized tensor in shared memory + + Args: + shape: Shape of the tensor + dtype: PyTorch dtype + + Returns: + SharedTensor instance with zeros + """ + shared_tensor = cls.empty(shape, dtype) + shared_tensor.tensor.zero_() + return shared_tensor + + @classmethod + def ones( + cls, + shape: Union[Tuple[int, ...], torch.Size], + dtype: torch.dtype = torch.float32, + ): + """ + Create a ones-initialized tensor in shared memory + + Args: + shape: Shape of the tensor + dtype: PyTorch dtype + + Returns: + SharedTensor instance with ones + """ + shared_tensor = cls.empty(shape, dtype) + shared_tensor.tensor.fill_(1) + return shared_tensor + + def _create_empty(self, shape, dtype): + """Initialize with empty tensor in shared memory""" + # Store metadata + self._shape = tuple(shape) if not isinstance(shape, tuple) else shape + self._dtype = dtype + self._dtype_str = str(dtype) + + # Calculate size + element_size = torch.tensor([], dtype=dtype).element_size() + total_elements = int(np.prod(self._shape)) + byte_size = total_elements * element_size + + # Create shared memory (uninitialized - fast!) + shm_name = f"shared_tensor_{uuid.uuid4().hex}" + self._shm = shared_memory.SharedMemory( + create=True, size=byte_size, name=shm_name + ) + self._shm_name = shm_name + + def _create_from_tensor(self, tensor): + """Initialize from an existing tensor""" + tensor = tensor.contiguous() + + # Store metadata + self._shape = tuple(tensor.shape) + self._dtype = tensor.dtype + self._dtype_str = str(tensor.dtype) + + # Create shared memory + byte_size = tensor.numel() * tensor.element_size() + shm_name = f"shared_tensor_{uuid.uuid4().hex}" + + self._shm = shared_memory.SharedMemory( + create=True, size=byte_size, name=shm_name + ) + self._shm_name = shm_name + + # Copy data as raw bytes + raw_bytes = tensor.view(torch.uint8).view(-1).cpu().contiguous().numpy() + self._shm.buf[:byte_size] = raw_bytes + + def _create_from_handle(self, handle: SharedTensorHandle): + """Initialize from a handle""" + self._shm_name = handle.shm_name + self._shape = handle.shape + self._dtype_str = handle.dtype + self._dtype = self._parse_dtype(self._dtype_str) + + # Attach to existing shared memory\ + self._shm = shared_memory.SharedMemory(name=self._shm_name) + + def _create_tensor_view(self): + """Create tensor view of shared memory.""" + element_size = torch.tensor([], dtype=self._dtype).element_size() + total_elements = int(np.prod(self._shape)) + byte_size = total_elements * element_size + + # Create numpy array that shares the buffer + np_array = np.ndarray(shape=(byte_size,), dtype=np.uint8, buffer=self._shm.buf) + # Create torch tensor from numpy (shares memory) + uint8_tensor = torch.from_numpy(np_array) + tensor = uint8_tensor.view(self._dtype).reshape(self._shape) + + # Keep both the np array and the SharedTensor object alive + tensor._forge_shared_tensor = self + tensor._forge_np_array = np_array + + return tensor + + def _parse_dtype(self, dtype_str): + """Parse dtype string""" + dtype_str = dtype_str.replace("torch.", "") + return getattr(torch, dtype_str) + + def get_handle(self): + """Get picklable handle""" + return SharedTensorHandle( + shm_name=self._shm_name, + shape=self._shape, + dtype=self._dtype_str, + ) + + 
@functools.cached_property
+    def tensor(self):
+        """Get the underlying tensor"""
+        return self._create_tensor_view()
+
+    def copy_from(self, source_tensor):
+        """
+        Copy data from another tensor into this shared tensor
+        Useful when you create empty tensor first, then fill it
+
+        Args:
+            source_tensor: Source tensor to copy from
+        """
+        if source_tensor.shape != self._shape:
+            raise ValueError(f"Shape mismatch: {source_tensor.shape} vs {self._shape}")
+        # Copy data
+        self.tensor.copy_(source_tensor)
+
+    def clone(self):
+        """Create a new SharedTensor with copied data"""
+        new_shared = SharedTensor.empty(self._shape, self._dtype)
+        new_shared.tensor.copy_(self.tensor)
+        return new_shared
+
+    def drop(self):
+        """
+        Release and unlink the shared memory.
+
+        This method closes the shared memory handle and removes the shared memory
+        segment from the system. After calling this method, the shared memory
+        will no longer be accessible by any process.
+
+        Note:
+            This should be called when the shared tensor is no longer needed.
+            Failing to call this method may result in shared memory leaks.
+        """
+        try:
+            self._shm.close()
+            self._shm.unlink()
+        except Exception:
+            pass
+
+    def __del__(self):
+        """Cleanup on deletion"""
+        if hasattr(self, "_shm"):
+            try:
+                self._shm.close()
+            except Exception:
+                pass
+
+    def __repr__(self):
+        return f"SharedTensor(shape={self._shape}, dtype={self._dtype}, shm_name={self._shm_name})"
diff --git a/tests/unit_tests/util/test_shared_tensor.py b/tests/unit_tests/util/test_shared_tensor.py
new file mode 100644
index 000000000..094567449
--- /dev/null
+++ b/tests/unit_tests/util/test_shared_tensor.py
@@ -0,0 +1,692 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+ +import pickle +import time +from multiprocessing import Process, Queue + +import pytest +import torch + +# Assuming SharedTensor is in shared_tensor.py +from forge.util._shared_tensor import SharedTensor + + +class TestSharedTensorCreation: + """Test tensor creation methods""" + + def test_empty_creation(self): + """Test creating empty tensor""" + shape = (100, 200) + dtype = torch.float32 + + shared = SharedTensor.empty(shape, dtype) + + assert shared.tensor.shape == torch.Size(shape) + assert shared.tensor.dtype == dtype + assert shared.tensor.shape == torch.Size(shape) + assert shared.tensor.dtype == dtype + + shared.drop() + + def test_empty_with_bfloat16(self): + """Test creating empty bfloat16 tensor""" + shape = (50, 50) + shared = SharedTensor.empty(shape, torch.bfloat16) + + assert shared.tensor.dtype == torch.bfloat16 + assert shared.tensor.dtype == torch.bfloat16 + + shared.drop() + + def test_zeros_creation(self): + """Test creating zero-initialized tensor""" + shape = (10, 20) + shared = SharedTensor.zeros(shape, torch.float32) + + tensor = shared.tensor + assert torch.all(tensor == 0) + assert tensor.sum().item() == 0.0 + + shared.drop() + + def test_ones_creation(self): + """Test creating ones-initialized tensor""" + shape = (10, 20) + shared = SharedTensor.ones(shape, torch.float32) + + tensor = shared.tensor + assert torch.all(tensor == 1) + assert tensor.sum().item() == 200.0 + + shared.drop() + + def test_from_tensor_creation(self): + """Test creating from existing tensor""" + original = torch.randn(50, 50) + shared = SharedTensor(tensor=original) + + assert shared.tensor.shape == original.shape + assert shared.tensor.dtype == original.dtype + assert torch.allclose(shared.tensor, original) + + shared.drop() + + def test_from_handle_creation(self): + """Test creating from handle""" + # Create original + original = SharedTensor.empty((10, 10), torch.float32) + original.tensor.fill_(5.0) + + # Get handle + handle = original.get_handle() + + # Create from handle + reconstructed = SharedTensor(handle=handle) + + assert torch.all(reconstructed.tensor == 5.0) + assert reconstructed.tensor.shape == original.tensor.shape + assert reconstructed.tensor.dtype == original.tensor.dtype + + original.drop() + + def test_creation_requires_argument(self): + """Test that creation without arguments raises error""" + with pytest.raises(ValueError, match="Must provide either tensor or handle"): + SharedTensor() + + @pytest.mark.parametrize( + "shape", + [ + (10,), + (10, 20), + (5, 10, 15), + (2, 3, 4, 5), + ], + ) + def test_various_shapes(self, shape): + """Test creation with various shapes""" + shared = SharedTensor.empty(shape, torch.float32) + assert shared.tensor.shape == torch.Size(shape) + assert shared.tensor.shape == torch.Size(shape) + shared.drop() + + +class TestSharedTensorDtypes: + """Test all supported dtypes""" + + @pytest.mark.parametrize( + "dtype", + [ + torch.float32, + torch.float64, + torch.float16, + torch.bfloat16, + torch.int32, + torch.int64, + torch.int16, + torch.int8, + torch.uint8, + torch.bool, + ], + ) + def test_all_dtypes(self, dtype): + """Test that all dtypes work correctly""" + shape = (10, 10) + shared = SharedTensor.empty(shape, dtype) + + assert shared.tensor.dtype == dtype + assert shared.tensor.dtype == dtype + + # Test that we can write to it + if dtype == torch.bool: + shared.tensor.fill_(True) + elif dtype in [torch.int32, torch.int64, torch.int16, torch.int8, torch.uint8]: + shared.tensor.fill_(42) + else: + shared.tensor.fill_(3.14) + + 
shared.drop() + + def test_dtype_conversion_in_handle(self): + """Test dtype is preserved through handle""" + for dtype in [torch.float32, torch.bfloat16, torch.int64]: + shared1 = SharedTensor.empty((5, 5), dtype) + handle = shared1.get_handle() + + shared2 = SharedTensor(handle=handle) + assert shared2.tensor.dtype == dtype + + shared1.drop() + + +class TestSharedTensorOperations: + """Test tensor operations""" + + def test_copy_from(self): + """Test copying data from another tensor""" + source = torch.randn(20, 30) + shared = SharedTensor.empty((20, 30), torch.float32) + + shared.copy_from(source) + + assert torch.allclose(shared.tensor, source) + shared.drop() + + def test_copy_from_shape_mismatch(self): + """Test copy_from raises error on shape mismatch""" + source = torch.randn(10, 10) + shared = SharedTensor.empty((20, 20), torch.float32) + + with pytest.raises(ValueError, match="Shape mismatch"): + shared.copy_from(source) + + shared.drop() + + def test_clone(self): + """Test cloning creates independent copy""" + original = SharedTensor.empty((10, 10), torch.float32) + original.tensor.fill_(5.0) + + cloned = original.clone() + + # Verify data is same + assert torch.all(cloned.tensor == 5.0) + + # Verify they're independent + original.tensor.fill_(10.0) + assert torch.all(cloned.tensor == 5.0) + assert torch.all(original.tensor == 10.0) + + original.drop() + cloned.drop() + + def test_tensor_modifications(self): + """Test that modifications to tensor are reflected""" + shared = SharedTensor.zeros((10, 10), torch.float32) + tensor = shared.tensor + + tensor[0, 0] = 99.0 + tensor[5:, :] = 42.0 + + # Get tensor again and verify changes persist + tensor2 = shared.tensor + assert tensor2[0, 0].item() == 99.0 + assert torch.all(tensor2[5:, :] == 42.0) + + shared.drop() + + def test_inplace_operations(self): + """Test in-place operations work""" + shared = SharedTensor.empty((100, 100), torch.float32) + tensor = shared.tensor + + tensor.normal_(0, 1) + mean = tensor.mean().item() + + tensor.add_(5.0) + new_mean = tensor.mean().item() + + assert abs(new_mean - (mean + 5.0)) < 0.1 + + shared.drop() + + +class TestSharedTensorSerialization: + """Test pickling and handle serialization""" + + def test_handle_is_picklable(self): + """Test that handle can be pickled""" + shared = SharedTensor.empty((10, 10), torch.float32) + handle = shared.get_handle() + + # Pickle and unpickle + pickled = pickle.dumps(handle) + unpickled_handle = pickle.loads(pickled) + + assert unpickled_handle == handle + + shared.drop() + + def test_handle_small_size(self): + """Test that handle is small (efficient for RPC)""" + shared = SharedTensor.empty((10000, 10000), torch.float32) + handle = shared.get_handle() + + pickled = pickle.dumps(handle) + + # Handle should be < 1KB even for huge tensors + assert len(pickled) < 1024 + + shared.drop() + + def test_data_integrity_after_pickle(self): + """Test data is preserved through handle pickling""" + # Create and fill tensor + shared1 = SharedTensor.empty((50, 50), torch.bfloat16) + shared1.tensor.normal_(0, 1) + original_data = shared1.tensor.clone() + + # Pickle handle + handle = shared1.get_handle() + pickled = pickle.dumps(handle) + unpickled_handle = pickle.loads(pickled) + + # Reconstruct + shared2 = SharedTensor(handle=unpickled_handle) + + # Verify data is same + assert torch.allclose(shared2.tensor.float(), original_data.float(), rtol=1e-3) + + shared1.drop() + + +class TestSharedTensorMemory: + """Test memory management and cleanup""" + + def test_drop(self): + 
"""Test drop removes shared memory""" + shared = SharedTensor.empty((10, 10), torch.float32) + shm_name = shared._shm_name + + # Verify shared memory exists + tensor = shared.tensor + tensor.fill_(5.0) + + # Drop shared memory + shared.drop() + + # Trying to attach should fail + from multiprocessing import shared_memory + + with pytest.raises(FileNotFoundError): + shared_memory.SharedMemory(name=shm_name) + + def test_multiple_views_same_memory(self): + """Test multiple tensor views point to same memory""" + shared = SharedTensor.empty((10, 10), torch.float32) + + tensor1 = shared.tensor + tensor1.fill_(5.0) + + tensor2 = shared.tensor + assert torch.all(tensor2 == 5.0) + + # Modify through tensor2 + tensor2.fill_(10.0) + + # Verify tensor1 sees the change + assert torch.all(tensor1 == 10.0) + + shared.drop() + + def test_handle_reconstruction_shares_memory(self): + """Test that handle reconstruction shares same memory""" + shared1 = SharedTensor.empty((20, 20), torch.float32) + shared1.tensor.fill_(7.0) + + handle = shared1.get_handle() + shared2 = SharedTensor(handle=handle) + + # Modify through shared2 + shared2.tensor.fill_(14.0) + + # Verify shared1 sees the change + assert torch.all(shared1.tensor == 14.0) + + shared1.drop() + + +class TestSharedTensorEdgeCases: + """Test edge cases and error conditions""" + + def test_empty_shape(self): + """Test scalar tensor (empty shape)""" + shared = SharedTensor.ones((), torch.float32) + assert shared.tensor.shape == () + assert shared.tensor.numel() == 1 + assert torch.allclose( + shared.tensor, + torch.ones( + (), + ), + ) + shared.drop() + + def test_single_element_tensor(self): + """Test 1-element tensor""" + shared = SharedTensor.empty((1,), torch.float32) + shared.tensor.fill_(42.0) + assert shared.tensor.item() == 42.0 + shared.drop() + + def test_large_tensor(self): + """Test large tensor (1GB)""" + # 1GB tensor: 250M float32 elements + shape = (250_000_000,) + shared = SharedTensor.empty(shape, torch.float32) + + assert shared.tensor.shape == shape + assert shared.tensor.numel() == 250_000_000 + + shared.drop() + + def test_non_contiguous_tensor_conversion(self): + """Test that non-contiguous tensors are handled""" + # Create non-contiguous tensor + original = torch.randn(10, 10).t() # Transpose makes it non-contiguous + assert not original.is_contiguous() + + # Should work (internally makes contiguous) + shared = SharedTensor(tensor=original) + + # Result should match + assert torch.allclose(shared.tensor, original) + + shared.drop() + + def test_repr(self): + """Test string representation""" + shared = SharedTensor.empty((10, 20), torch.float32) + repr_str = repr(shared) + + assert "SharedTensor" in repr_str + assert "10, 20" in repr_str + assert "float32" in repr_str + assert shared._shm_name in repr_str + + shared.drop() + + +class TestSharedTensorMultiprocess: + """Test multiprocess scenarios""" + + def test_multiprocess_read(self): + """Test reading shared tensor from another process""" + + def reader_process(handle_dict, result_queue): + shared = SharedTensor(handle=handle_dict) + tensor = shared.tensor + result_queue.put(tensor.sum().item()) + + # Create shared tensor in main process + shared = SharedTensor.empty((100, 100), torch.float32) + shared.tensor.fill_(5.0) + + # Read from child process + result_queue = Queue() + handle = shared.get_handle() + + p = Process(target=reader_process, args=(handle, result_queue)) + p.start() + p.join() + + result = result_queue.get() + expected = 5.0 * 100 * 100 + + assert abs(result - 
expected) < 1e-5 + + shared.drop() + + def test_multiprocess_write(self): + """Test writing to shared tensor from another process""" + + def writer_process(handle_dict, value): + shared = SharedTensor(handle=handle_dict) + shared.tensor.fill_(value) + + # Create empty shared tensor + shared = SharedTensor.empty((50, 50), torch.float32) + shared.tensor.zero_() + + # Write from child process + handle = shared.get_handle() + + p = Process(target=writer_process, args=(handle, 42.0)) + p.start() + p.join() + + # Verify in main process + assert torch.all(shared.tensor == 42.0) + + shared.drop() + + def test_multiprocess_bidirectional(self): + """Test bidirectional communication""" + + def worker_process(input_handle, output_handle): + input_tensor = SharedTensor(handle=input_handle).tensor + output_tensor = SharedTensor(handle=output_handle).tensor + + # Compute: output = input * 2 + output_tensor.copy_(input_tensor * 2) + + # Create input and output tensors + input_shared = SharedTensor.empty((100, 100), torch.float32) + input_shared.tensor.normal_(0, 1) + input_data = input_shared.tensor.clone() + + output_shared = SharedTensor.empty((100, 100), torch.float32) + + # Process in child + p = Process( + target=worker_process, + args=(input_shared.get_handle(), output_shared.get_handle()), + ) + p.start() + p.join() + + # Verify result + expected = input_data * 2 + assert torch.allclose( + output_shared.tensor, expected + ), "output: {}, expected: {}".format(output_shared.tensor, expected) + + input_shared.drop() + output_shared.drop() + + +class TestSharedTensorPerformance: + """Performance-related tests""" + + def test_empty_faster_than_from_tensor(self): + """Test that empty() is faster than from tensor""" + shape = (1000, 1000) + + # Time empty creation + start = time.time() + for _ in range(10): + shared = SharedTensor.empty(shape, torch.float32) + shared.drop() + empty_time = time.time() - start + + # Time from_tensor creation + start = time.time() + for _ in range(10): + tensor = torch.randn(shape) + shared = SharedTensor(tensor=tensor) + shared.drop() + from_tensor_time = time.time() - start + + # empty() should be faster (no data copying) + assert empty_time < from_tensor_time + + def test_handle_serialization_fast(self): + """Test that handle serialization is fast""" + shared = SharedTensor.empty((10000, 10000), torch.float32) + handle = shared.get_handle() + + start = time.time() + for _ in range(1000): + pickled = pickle.dumps(handle) + unpickled = pickle.loads(pickled) + elapsed = time.time() - start + + # Should be able to do 1000 round trips in < 0.1 seconds + assert elapsed < 0.1 + + shared.drop() + + +class TestSharedTensorHandleToSharedTensor: + """Test SharedTensorHandle.to_shared_tensor() method""" + + def test_to_shared_tensor_basic(self): + """Test basic creation of SharedTensor from handle using to_shared_tensor method""" + original = SharedTensor.empty((10, 10), torch.float32) + original.tensor.fill_(7.0) + + handle = original.get_handle() + reconstructed = handle.to_shared_tensor() + + assert torch.all(reconstructed.tensor == 7.0) + assert reconstructed.tensor.shape == original.tensor.shape + assert reconstructed.tensor.dtype == original.tensor.dtype + + original.drop() + + def test_to_shared_tensor_preserves_data(self): + """Test that to_shared_tensor preserves original data""" + original = SharedTensor.empty((20, 30), torch.float32) + original.tensor.normal_(0, 1) + original_data = original.tensor.clone() + + handle = original.get_handle() + reconstructed = 
handle.to_shared_tensor() + + assert torch.allclose(reconstructed.tensor, original_data) + + original.drop() + + def test_to_shared_tensor_shares_memory(self): + """Test that to_shared_tensor shares memory with original""" + original = SharedTensor.empty((15, 15), torch.float32) + original.tensor.fill_(5.0) + + handle = original.get_handle() + reconstructed = handle.to_shared_tensor() + + reconstructed.tensor.fill_(10.0) + + assert torch.all(original.tensor == 10.0) + + original.drop() + + def test_to_shared_tensor_with_various_dtypes(self): + """Test to_shared_tensor works with different data types""" + for dtype in [torch.float32, torch.float64, torch.bfloat16, torch.int32]: + original = SharedTensor.empty((5, 5), dtype) + if ( + dtype == torch.bfloat16 + or dtype == torch.float32 + or dtype == torch.float64 + ): + original.tensor.normal_(0, 1) + else: + original.tensor.fill_(42) + + handle = original.get_handle() + reconstructed = handle.to_shared_tensor() + + assert reconstructed.tensor.dtype == dtype + if dtype == torch.bfloat16: + assert torch.allclose( + reconstructed.tensor.float(), original.tensor.float(), rtol=1e-3 + ) + else: + assert torch.allclose(reconstructed.tensor, original.tensor) + + original.drop() + + def test_to_shared_tensor_multiprocess(self): + """Test to_shared_tensor in multiprocess scenario""" + + def worker_process(handle, result_queue): + shared = handle.to_shared_tensor() + result_queue.put(shared.tensor.sum().item()) + + original = SharedTensor.empty((50, 50), torch.float32) + original.tensor.fill_(3.0) + + handle = original.get_handle() + result_queue = Queue() + + p = Process(target=worker_process, args=(handle, result_queue)) + p.start() + p.join() + + result = result_queue.get() + expected = 3.0 * 50 * 50 + + assert abs(result - expected) < 1e-5 + + original.drop() + + def test_to_shared_tensor_equivalent_to_constructor(self): + """Test that handle.to_shared_tensor() is equivalent to SharedTensor(handle=handle)""" + original = SharedTensor.empty((25, 25), torch.float32) + original.tensor.normal_(0, 1) + + handle = original.get_handle() + + via_method = handle.to_shared_tensor() + via_constructor = SharedTensor(handle=handle) + + assert torch.allclose(via_method.tensor, via_constructor.tensor) + assert via_method.tensor.shape == via_constructor.tensor.shape + assert via_method.tensor.dtype == via_constructor.tensor.dtype + + original.drop() + + +class TestSharedTensorBfloat16: + """Specific tests for bfloat16 support""" + + def test_bfloat16_creation(self): + """Test bfloat16 tensor creation""" + shared = SharedTensor.empty((100, 100), torch.bfloat16) + assert shared.tensor.dtype == torch.bfloat16 + shared.drop() + + def test_bfloat16_from_tensor(self): + """Test creating shared tensor from bfloat16 tensor""" + original = torch.randn(50, 50, dtype=torch.bfloat16) + shared = SharedTensor(tensor=original) + + assert shared.tensor.dtype == torch.bfloat16 + assert torch.allclose(shared.tensor.float(), original.float(), rtol=1e-3) + + shared.drop() + + def test_bfloat16_handle_preservation(self): + """Test bfloat16 dtype preserved through handle""" + shared1 = SharedTensor.empty((20, 20), torch.bfloat16) + shared1.tensor.normal_(0, 1) + + handle = shared1.get_handle() + shared2 = SharedTensor(handle=handle) + + assert shared2.tensor.dtype == torch.bfloat16 + assert torch.allclose(shared1.tensor.float(), shared2.tensor.float(), rtol=1e-3) + + shared1.drop() + + def test_bfloat16_operations(self): + """Test operations on bfloat16 tensors""" + shared = 
SharedTensor.empty((100, 100), torch.bfloat16) + tensor = shared.tensor + + tensor.normal_(0, 1) + mean = tensor.float().mean().item() + + # Mean should be close to 0 + assert abs(mean) < 0.1 + + shared.drop() + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "--tb=short"])
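
Usage note (not part of the diff): the weight-update path copies each parameter into shared memory once, ships only the small picklable SharedTensorHandle objects to the workers, and reattaches zero-copy views on the receiving side. Below is a minimal sketch of that round-trip using only the API introduced in src/forge/util/_shared_tensor.py; the helper names make_shared_state_dict and load_from_handles are illustrative and do not appear in this change.

import torch

from forge.util._shared_tensor import SharedTensor, SharedTensorHandle


def make_shared_state_dict(
    weights: dict[str, torch.Tensor],
) -> dict[str, SharedTensorHandle]:
    # Copy each tensor into a shared-memory segment; only the handles need to be pickled.
    return {name: SharedTensor(tensor=t).get_handle() for name, t in weights.items()}


def load_from_handles(
    handles: dict[str, SharedTensorHandle],
) -> dict[str, torch.Tensor]:
    # Reattach to the existing segments; the returned tensors are zero-copy views.
    return {name: handle.to_shared_tensor().tensor for name, handle in handles.items()}


if __name__ == "__main__":
    weights = {"layer.weight": torch.randn(4, 4), "layer.bias": torch.zeros(4)}
    handles = make_shared_state_dict(weights)  # cheap to send between actors/processes
    restored = load_from_handles(handles)
    assert torch.allclose(restored["layer.weight"], weights["layer.weight"])
    for handle in handles.values():
        # The producer owns cleanup, mirroring Generator._drop_shared_memory().
        handle.to_shared_tensor().drop()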