# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import functools
import uuid
from multiprocessing import shared_memory
from typing import Tuple, Union

import numpy as np
import torch


class SharedTensor:
    """Wrapper class for tensors backed by shared memory"""

    def __init__(self, tensor=None, handle=None):
        if tensor is not None:
            self._create_from_tensor(tensor)
        elif handle is not None:
            self._create_from_handle(handle)
        else:
            raise ValueError("Must provide either tensor or handle")

    @classmethod
    def empty(
        cls,
        shape: Union[Tuple[int, ...], torch.Size],
        dtype: torch.dtype = torch.float32,
    ):
        """
        Create an uninitialized tensor directly in shared memory (no copy and
        no initialization pass)

        Args:
            shape: Shape of the tensor
            dtype: PyTorch dtype (supports bfloat16, float32, etc.)

        Returns:
            SharedTensor instance with uninitialized data
        """
        instance = cls.__new__(cls)
        instance._create_empty(shape, dtype)
        return instance
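
    # empty() deliberately bypasses __init__ via cls.__new__, since neither a
    # source tensor nor a handle exists yet; _create_empty() sets up the same
    # state that __init__'s two paths would otherwise establish.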

    @classmethod
    def zeros(
        cls,
        shape: Union[Tuple[int, ...], torch.Size],
        dtype: torch.dtype = torch.float32,
    ):
        """
        Create a zero-initialized tensor in shared memory

        Args:
            shape: Shape of the tensor
            dtype: PyTorch dtype

        Returns:
            SharedTensor instance with zeros
        """
        shared_tensor = cls.empty(shape, dtype)
        shared_tensor.tensor.zero_()
        return shared_tensor

    @classmethod
    def ones(
        cls,
        shape: Union[Tuple[int, ...], torch.Size],
        dtype: torch.dtype = torch.float32,
    ):
        """
        Create a ones-initialized tensor in shared memory

        Args:
            shape: Shape of the tensor
            dtype: PyTorch dtype

        Returns:
            SharedTensor instance with ones
        """
        shared_tensor = cls.empty(shape, dtype)
        shared_tensor.tensor.fill_(1)
        return shared_tensor
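
    # Sketch of the intended allocate-then-fill pattern (the shape and dtype
    # here are made up for illustration):
    #
    #     buf = SharedTensor.empty((4, 1024), dtype=torch.bfloat16)
    #     buf.copy_from(torch.randn(4, 1024, dtype=torch.bfloat16))
    #
    # empty() skips initialization entirely, so it is the cheapest way to
    # stage a buffer that another process will fill.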

    def _create_empty(self, shape, dtype):
        """Initialize with an empty tensor in shared memory"""
        # Store metadata
        self.shape = tuple(shape)
        self.dtype = dtype
        self.dtype_str = str(dtype)

        # Calculate size in bytes
        element_size = torch.tensor([], dtype=dtype).element_size()
        total_elements = int(np.prod(self.shape))
        byte_size = total_elements * element_size

        # Create shared memory (uninitialized - fast!). SharedMemory requires
        # a positive size, so allocate at least one byte for empty tensors.
        shm_name = f"shared_tensor_{uuid.uuid4().hex}"
        self.shm = shared_memory.SharedMemory(
            create=True, size=max(byte_size, 1), name=shm_name
        )
        self.shm_name = shm_name

    def _create_from_tensor(self, tensor):
        """Initialize from an existing tensor"""
        tensor = tensor.contiguous()

        # Store metadata
        self.shape = tuple(tensor.shape)
        self.dtype = tensor.dtype
        self.dtype_str = str(tensor.dtype)

        # Create shared memory (at least one byte, since SharedMemory rejects
        # a zero size)
        byte_size = tensor.numel() * tensor.element_size()
        shm_name = f"shared_tensor_{uuid.uuid4().hex}"

        self.shm = shared_memory.SharedMemory(
            create=True, size=max(byte_size, 1), name=shm_name
        )
        self.shm_name = shm_name

        # Copy the data as raw bytes; viewing as uint8 keeps the copy
        # dtype-agnostic (numpy has no bfloat16, so raw bytes are the common
        # denominator)
        raw_bytes = tensor.view(torch.uint8).view(-1).cpu().contiguous().numpy()
        self.shm.buf[:byte_size] = raw_bytes

    def _create_from_handle(self, handle):
        """Initialize from a handle"""
        self.shm_name = handle["shm_name"]
        self.shape = handle["shape"]
        self.dtype_str = handle["dtype"]
        self.dtype = self._parse_dtype(self.dtype_str)

        # Attach to existing shared memory
        self.shm = shared_memory.SharedMemory(name=self.shm_name)
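
    # Attaching via a handle never creates or unlinks the segment; its
    # lifetime is owned by the creating process, which should call cleanup()
    # once every consumer is done with it.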

    def _create_tensor_view(self):
        """Create tensor view of shared memory."""
        element_size = torch.tensor([], dtype=self.dtype).element_size()
        total_elements = int(np.prod(self.shape))
        byte_size = total_elements * element_size

        # Create numpy array that shares the buffer
        np_array = np.ndarray(shape=(byte_size,), dtype=np.uint8, buffer=self.shm.buf)
        # Create torch tensor from numpy (shares memory)
        uint8_tensor = torch.from_numpy(np_array)
        tensor = uint8_tensor.view(self.dtype).reshape(self.shape)

        # Keep both the np array and the SharedTensor object alive
        tensor._forge_shared_tensor = self
        tensor._forge_np_array = np_array

        return tensor

    def _parse_dtype(self, dtype_str):
        """Parse a dtype string like "torch.bfloat16" back into a torch.dtype"""
        dtype_str = dtype_str.replace("torch.", "")
        return getattr(torch, dtype_str)

    def get_handle(self):
        """Get a picklable handle for attaching from another process"""
        return {"shm_name": self.shm_name, "shape": self.shape, "dtype": self.dtype_str}
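
    # The handle is a plain dict, so it can be pickled across process
    # boundaries. A consumer-side sketch (the transport is assumed, e.g. a
    # multiprocessing.Queue):
    #
    #     handle = queue.get()
    #     shared = SharedTensor(handle=handle)  # zero-copy attach
    #     shared.tensor.add_(1)                 # visible to every process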

    @functools.cached_property
    def tensor(self):
        """Get the underlying tensor (the view is created once and cached)"""
        return self._create_tensor_view()

    def copy_from(self, source_tensor):
        """
        Copy data from another tensor into this shared tensor.
        Useful when you create an empty tensor first, then fill it.

        Args:
            source_tensor: Source tensor to copy from
        """
        if source_tensor.shape != self.shape:
            raise ValueError(f"Shape mismatch: {source_tensor.shape} vs {self.shape}")
        # Copy data
        self.tensor.copy_(source_tensor)

    def clone(self):
        """Create a new SharedTensor with copied data"""
        new_shared = SharedTensor.empty(self.shape, self.dtype)
        new_shared.tensor.copy_(self.tensor)
        return new_shared

    def cleanup(self):
        """Close this mapping and unlink the shared memory segment"""
        # Drop the cached tensor view first: an outstanding numpy view of
        # shm.buf makes close() raise BufferError.
        self.__dict__.pop("tensor", None)
        try:
            self.shm.close()
            self.shm.unlink()
        except Exception:
            pass

    def __del__(self):
        """Close (but do not unlink) the mapping when garbage-collected"""
        if hasattr(self, "shm"):
            try:
                self.shm.close()
            except Exception:
                pass

    def __repr__(self):
        return f"SharedTensor(shape={self.shape}, dtype={self.dtype}, shm_name={self.shm_name})"
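

# Minimal usage sketch (assumptions: run as a script in a single process; in
# real use the handle would be pickled to a worker, e.g. over a
# multiprocessing.Queue, and the two SharedTensors would live in different
# processes).
if __name__ == "__main__":
    owner = SharedTensor.zeros((2, 3), dtype=torch.bfloat16)
    handle = owner.get_handle()

    # Attaching through the handle aliases the same physical memory, so
    # writes made via one view are visible through the other.
    attached = SharedTensor(handle=handle)
    attached.tensor.fill_(1.0)
    assert owner.tensor.sum().item() == 6.0
    print(owner)

    del attached  # drop our reference; its finalizer closes the mapping
    owner.cleanup()  # the creator closes and unlinks the segment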