Skip to content

Commit 10d336a

Browse files
committed
shared tensor util
1 parent 399b20d commit 10d336a

File tree

2 files changed

+786
-0
lines changed

2 files changed

+786
-0
lines changed

src/forge/util/_shared_tensor.py

Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import functools
8+
import uuid
9+
from multiprocessing import shared_memory
10+
from typing import Tuple, Union
11+
12+
import numpy as np
13+
import torch
14+
15+
16+
class SharedTensor:
17+
"""Wrapper class for tensors backed my shared memory"""
18+
19+
def __init__(self, tensor=None, handle=None):
20+
if tensor is not None:
21+
self._create_from_tensor(tensor)
22+
elif handle is not None:
23+
self._create_from_handle(handle)
24+
else:
25+
raise ValueError("Must provide either tensor or handle")
26+
27+
@classmethod
28+
def empty(
29+
cls,
30+
shape: Union[Tuple[int, ...], torch.Size],
31+
dtype: torch.dtype = torch.float32,
32+
):
33+
"""
34+
Create an empty tensor directly in shared memory (no copy/allocation overhead)
35+
36+
Args:
37+
shape: Shape of the tensor
38+
dtype: PyTorch dtype (supports bfloat16, float32, etc.)
39+
40+
Returns:
41+
SharedTensor instance with uninitialized data
42+
"""
43+
instance = cls.__new__(cls)
44+
instance._create_empty(shape, dtype)
45+
return instance
46+
47+
@classmethod
48+
def zeros(
49+
cls,
50+
shape: Union[Tuple[int, ...], torch.Size],
51+
dtype: torch.dtype = torch.float32,
52+
):
53+
"""
54+
Create a zero-initialized tensor in shared memory
55+
56+
Args:
57+
shape: Shape of the tensor
58+
dtype: PyTorch dtype
59+
60+
Returns:
61+
SharedTensor instance with zeros
62+
"""
63+
shared_tensor = cls.empty(shape, dtype)
64+
shared_tensor.tensor.zero_()
65+
return shared_tensor
66+
67+
@classmethod
68+
def ones(
69+
cls,
70+
shape: Union[Tuple[int, ...], torch.Size],
71+
dtype: torch.dtype = torch.float32,
72+
):
73+
"""
74+
Create a ones-initialized tensor in shared memory
75+
76+
Args:
77+
shape: Shape of the tensor
78+
dtype: PyTorch dtype
79+
80+
Returns:
81+
SharedTensor instance with ones
82+
"""
83+
shared_tensor = cls.empty(shape, dtype)
84+
shared_tensor.tensor.fill_(1)
85+
return shared_tensor
86+
87+
def _create_empty(self, shape, dtype):
88+
"""Initialize with empty tensor in shared memory"""
89+
# Store metadata
90+
self.shape = tuple(shape) if not isinstance(shape, tuple) else shape
91+
self.dtype = dtype
92+
self.dtype_str = str(dtype)
93+
94+
# Calculate size
95+
element_size = torch.tensor([], dtype=dtype).element_size()
96+
total_elements = int(np.prod(self.shape))
97+
byte_size = total_elements * element_size
98+
99+
# Create shared memory (uninitialized - fast!)
100+
shm_name = f"shared_tensor_{uuid.uuid4().hex}"
101+
self.shm = shared_memory.SharedMemory(
102+
create=True, size=byte_size, name=shm_name
103+
)
104+
self.shm_name = shm_name
105+
106+
def _create_from_tensor(self, tensor):
107+
"""Initialize from an existing tensor"""
108+
tensor = tensor.contiguous()
109+
110+
# Store metadata
111+
self.shape = tuple(tensor.shape)
112+
self.dtype = tensor.dtype
113+
self.dtype_str = str(tensor.dtype)
114+
115+
# Create shared memory
116+
byte_size = tensor.numel() * tensor.element_size()
117+
shm_name = f"shared_tensor_{uuid.uuid4().hex}"
118+
119+
self.shm = shared_memory.SharedMemory(
120+
create=True, size=byte_size, name=shm_name
121+
)
122+
self.shm_name = shm_name
123+
124+
# Copy data as raw bytes
125+
raw_bytes = tensor.view(torch.uint8).view(-1).cpu().contiguous().numpy()
126+
self.shm.buf[:byte_size] = raw_bytes
127+
128+
def _create_from_handle(self, handle):
129+
"""Initialize from a handle"""
130+
self.shm_name = handle["shm_name"]
131+
self.shape = handle["shape"]
132+
self.dtype_str = handle["dtype"]
133+
self.dtype = self._parse_dtype(self.dtype_str)
134+
135+
# Attach to existing shared memory
136+
self.shm = shared_memory.SharedMemory(name=self.shm_name)
137+
138+
def _create_tensor_view(self):
139+
"""Create tensor view of shared memory."""
140+
element_size = torch.tensor([], dtype=self.dtype).element_size()
141+
total_elements = int(np.prod(self.shape))
142+
byte_size = total_elements * element_size
143+
144+
# Create numpy array that shares the buffer
145+
np_array = np.ndarray(shape=(byte_size,), dtype=np.uint8, buffer=self.shm.buf)
146+
# Create torch tensor from numpy (shares memory)
147+
uint8_tensor = torch.from_numpy(np_array)
148+
tensor = uint8_tensor.view(self.dtype).reshape(self.shape)
149+
150+
# Keep both the np array and the SharedTensor object alive
151+
tensor._forge_shared_tensor = self
152+
tensor._forge_np_array = np_array
153+
154+
return tensor
155+
156+
def _parse_dtype(self, dtype_str):
157+
"""Parse dtype string"""
158+
dtype_str = dtype_str.replace("torch.", "")
159+
return getattr(torch, dtype_str)
160+
161+
def get_handle(self):
162+
"""Get picklable handle"""
163+
return {"shm_name": self.shm_name, "shape": self.shape, "dtype": self.dtype_str}
164+
165+
@functools.cached_property
166+
def tensor(self):
167+
"""Get the underlying tensor"""
168+
return self._create_tensor_view()
169+
170+
def copy_from(self, source_tensor):
171+
"""
172+
Copy data from another tensor into this shared tensor
173+
Useful when you create empty tensor first, then fill it
174+
175+
Args:
176+
source_tensor: Source tensor to copy from
177+
"""
178+
if source_tensor.shape != self.shape:
179+
raise ValueError(f"Shape mismatch: {source_tensor.shape} vs {self.shape}")
180+
# Copy data
181+
self.tensor.copy_(source_tensor)
182+
183+
def clone(self):
184+
"""Create a new SharedTensor with copied data"""
185+
new_shared = SharedTensor.empty(self.shape, self.dtype)
186+
new_shared.tensor.copy_(self.tensor)
187+
return new_shared
188+
189+
def cleanup(self):
190+
"""Clean up shared memory"""
191+
try:
192+
self.shm.close()
193+
self.shm.unlink()
194+
except Exception:
195+
pass
196+
197+
def __del__(self):
198+
"""Cleanup on deletion"""
199+
if hasattr(self, "shm"):
200+
try:
201+
self.shm.close()
202+
except Exception:
203+
pass
204+
205+
def __repr__(self):
206+
return f"SharedTensor(shape={self.shape}, dtype={self.dtype}, shm_name={self.shm_name})"

0 commit comments

Comments
 (0)