60 changes: 57 additions & 3 deletions src/forge/actors/generator.py
@@ -13,6 +13,7 @@
from collections.abc import Mapping
from copy import copy
from dataclasses import dataclass, field
from typing import Optional

import torch
import torchstore as ts
@@ -57,6 +58,7 @@
from forge.observability.metrics import record_metric, Reduce
from forge.observability.perf_tracker import Tracer
from forge.types import ProcessConfig
from forge.util._shared_tensor import SharedTensor, SharedTensorHandle

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
@@ -386,6 +388,24 @@ async def run(self) -> None:
if len(self.requests) == 0:
self.request_lock.notify_all()

async def _fetch_weights(self, version: int) -> dict[str, SharedTensorHandle]:
"""Fetch weights from torchstore and return a dict of {name: SharedTensorHandle}."""
t = Tracer("policy_perf/_fetch_weights")
t.start()
prefix = get_param_prefix(version)
matching_keys = await ts.keys(prefix)
hf_param_names = [extract_param_name(key) for key in matching_keys]
# vLLM's load_weights is not async, so we can't hand it a lazy generator of parameters.
# Instead, fetch each parameter here and stage it in shared memory one at a time.
shared_memory_state_dict = {}
for name in hf_param_names:
param_key = get_param_key(version, name)
param = await ts.get(param_key)
# TODO: preallocate in the shared memory once we have plumbing in torchstore.
shared_memory_state_dict[name] = SharedTensor(tensor=param).get_handle()
t.stop()
return shared_memory_state_dict
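
The returned mapping contains picklable `SharedTensorHandle`s rather than tensors, so it can be shipped to worker processes without copying the weights again. A minimal sketch of how a consumer might materialize and use the handles (the `consume_weights` helper and `load_one_param` callback are hypothetical, standing in for the worker's actual loading step):

```python
from forge.util._shared_tensor import SharedTensorHandle

def consume_weights(handles: dict[str, SharedTensorHandle], load_one_param) -> None:
    """Attach to each shared-memory segment and hand the tensor to the loader."""
    for name, handle in handles.items():
        shared = handle.to_shared_tensor()   # attaches to the existing segment
        load_one_param(name, shared.tensor)  # zero-copy view over shared memory
        # The producer still owns the segment; only it should call drop().
```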

@endpoint
async def update_weights(self, version: int) -> None:
"""Update weights on base model from a generator version to be found in a torchstore volume.
@@ -400,6 +420,9 @@ async def update_weights(self, version: int) -> None:
>>> await trainer.push_weights()
>>> generator.update_weights(version)
"""
# TODO: the allocation inside ts.get currently blocks the event loop;
# we may need changes in torchstore to avoid this.
fetch_task = asyncio.create_task(self._fetch_weights(version))
# Serialize updates (only one update at a time)
async with self.update_lock:
# Grab the lock to stop accepting requests and wait on pending requests
@@ -431,9 +454,17 @@
)

logger.debug(f"Starting weight update on {self.__class__.__name__}")
# Call update_weights on every generator_worker
await self.generator_worker.update_weights.call(version=version)
# Wait for the background fetch, then push the shared-memory handles to every generator worker
t = Tracer("generator_perf/_waiting_for_fetch_weights")
t.start()
fetched_weights = await fetch_task
t.stop()
await self.generator_worker.update_weights.call(
shared_memory_state_dict=fetched_weights
)
self.generator_version = version
for _, handle in fetched_weights.items():
handle.to_shared_tensor().drop()

# After updating the weights, we need to reset the KV cache
self.scheduler.reset_prefix_cache()
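
The update path creates the `_fetch_weights` task before taking the update lock, so the torchstore reads and shared-memory copies overlap with draining in-flight requests, and the task is only awaited once the workers are ready to load. A stripped-down sketch of that overlap pattern (the coroutines are stand-ins, not the real methods):

```python
import asyncio

async def fetch_weights() -> dict:
    await asyncio.sleep(1.0)  # stands in for torchstore reads + shared-memory copies
    return {"layer.weight": "handle"}

async def drain_requests() -> None:
    await asyncio.sleep(0.5)  # stands in for waiting on pending generation requests

async def update() -> None:
    fetch_task = asyncio.create_task(fetch_weights())  # start fetching right away
    await drain_requests()                             # runs concurrently with the fetch
    handles = await fetch_task                         # usually ready by this point
    print(f"loading {len(handles)} params")

asyncio.run(update())
```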
@@ -597,8 +628,31 @@ async def execute_model(self, schedule: SchedulerOutput) -> ModelRunnerOutput:
return self.worker.execute_model(schedule)

@endpoint
async def update_weights(self, version: int) -> None:
async def update_weights(
self,
version: Optional[int] = None,
*,
shared_memory_state_dict: Optional[dict[str, SharedTensorHandle]] = None,
) -> None:
model = self.worker.model_runner.model
if shared_memory_state_dict is not None:
logger.info("[PolicyWorker] update weights from shared memory.")
t = Tracer(
"policy_worker_perf/update_weights_from_shared_memory", timer="gpu"
)
t.start()
loaded_weights = set()
for name, param_handle in shared_memory_state_dict.items():
param = param_handle.to_shared_tensor().tensor
loaded = model.load_weights([(name, param)])
loaded_weights.update(loaded)
logger.info(f"[PolicyWorker] updated {len(loaded_weights)} paremeters")
t.stop()
return
if version is None:
raise ValueError(
"version must be provided if not using shared_memory_state_dict"
)
prefix = get_param_prefix(version)
matching_keys = await ts.keys(prefix)
dcp_whole_state_dict_key = get_dcp_whole_state_dict_key(version)
244 changes: 244 additions & 0 deletions src/forge/util/_shared_tensor.py
@@ -0,0 +1,244 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

import functools
import uuid
from dataclasses import dataclass
from multiprocessing import shared_memory
from typing import Optional, Tuple, Union

import numpy as np
import torch


@dataclass
class SharedTensorHandle:
shm_name: str
shape: Tuple[int, ...]
dtype: str

def to_shared_tensor(self) -> SharedTensor:
"""
Create a SharedTensor from this handle.

Returns:
SharedTensor instance attached to the shared memory referenced by this handle
"""
return SharedTensor(handle=self)
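
A `SharedTensorHandle` is just the segment name plus shape and dtype metadata, so it is cheap to pickle and pass around. A minimal round-trip sketch on a single host:

```python
import torch
from forge.util._shared_tensor import SharedTensor

src = torch.randn(4, 8, dtype=torch.bfloat16)
shared = SharedTensor(tensor=src)        # copies src into a fresh shm segment
handle = shared.get_handle()             # small, picklable record

view = handle.to_shared_tensor().tensor  # zero-copy view over the same bytes
assert torch.equal(view, src)

shared.drop()                            # unlink the segment once done
```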


class SharedTensor:
"""Wrapper class for tensors backed my shared memory"""

def __init__(
self,
*,
tensor: Optional[torch.Tensor] = None,
handle: Optional[SharedTensorHandle] = None,
):
if tensor is not None:
self._create_from_tensor(tensor)
elif handle is not None:
self._create_from_handle(handle)
else:
raise ValueError("Must provide either tensor or handle")

@classmethod
def empty(
cls,
shape: Union[Tuple[int, ...], torch.Size],
dtype: torch.dtype = torch.float32,
):
"""
Create an empty tensor directly in shared memory (no copy or zero-initialization overhead)

Args:
shape: Shape of the tensor
dtype: PyTorch dtype (supports bfloat16, float32, etc.)

Returns:
SharedTensor instance with uninitialized data
"""
instance = cls.__new__(cls)
instance._create_empty(shape, dtype)
return instance

@classmethod
def zeros(
cls,
shape: Union[Tuple[int, ...], torch.Size],
dtype: torch.dtype = torch.float32,
):
"""
Create a zero-initialized tensor in shared memory

Args:
shape: Shape of the tensor
dtype: PyTorch dtype

Returns:
SharedTensor instance with zeros
"""
shared_tensor = cls.empty(shape, dtype)
shared_tensor.tensor.zero_()
return shared_tensor

@classmethod
def ones(
cls,
shape: Union[Tuple[int, ...], torch.Size],
dtype: torch.dtype = torch.float32,
):
"""
Create a ones-initialized tensor in shared memory

Args:
shape: Shape of the tensor
dtype: PyTorch dtype

Returns:
SharedTensor instance with ones
"""
shared_tensor = cls.empty(shape, dtype)
shared_tensor.tensor.fill_(1)
return shared_tensor
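
The `empty`/`zeros`/`ones` constructors allocate the buffer directly in shared memory, so values can be written in place without a staging tensor. A short usage sketch:

```python
import torch
from forge.util._shared_tensor import SharedTensor

# Allocate uninitialized shared memory, then fill it in place.
buf = SharedTensor.empty((1024, 1024), dtype=torch.bfloat16)
buf.copy_from(torch.randn(1024, 1024, dtype=torch.bfloat16))

zeroed = SharedTensor.zeros((16,), dtype=torch.float32)
assert zeroed.tensor.sum().item() == 0.0

buf.drop()
zeroed.drop()
```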

def _create_empty(self, shape, dtype):
"""Initialize with empty tensor in shared memory"""
# Store metadata
self._shape = tuple(shape) if not isinstance(shape, tuple) else shape
self._dtype = dtype
self._dtype_str = str(dtype)

# Calculate size
element_size = torch.tensor([], dtype=dtype).element_size()
total_elements = int(np.prod(self._shape))
byte_size = total_elements * element_size

# Create shared memory (uninitialized - fast!)
shm_name = f"shared_tensor_{uuid.uuid4().hex}"
self._shm = shared_memory.SharedMemory(
create=True, size=byte_size, name=shm_name
)
self._shm_name = shm_name

def _create_from_tensor(self, tensor):
"""Initialize from an existing tensor"""
tensor = tensor.contiguous()

# Store metadata
self._shape = tuple(tensor.shape)
self._dtype = tensor.dtype
self._dtype_str = str(tensor.dtype)

# Create shared memory
byte_size = tensor.numel() * tensor.element_size()
shm_name = f"shared_tensor_{uuid.uuid4().hex}"

self._shm = shared_memory.SharedMemory(
create=True, size=byte_size, name=shm_name
)
self._shm_name = shm_name

# Copy data as raw bytes
raw_bytes = tensor.view(torch.uint8).view(-1).cpu().contiguous().numpy()
self._shm.buf[:byte_size] = raw_bytes
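
The copy goes through a flat `uint8` view rather than `tensor.numpy()`, presumably because dtypes such as `bfloat16` have no NumPy equivalent; reinterpreting the storage as raw bytes sidesteps that. The same reinterpretation, shown outside of shared memory:

```python
import torch

t = torch.randn(2, 3, dtype=torch.bfloat16)

# Forward: reinterpret the storage as raw bytes (works for NumPy-unknown dtypes).
raw = t.view(torch.uint8).view(-1).contiguous().numpy()

# Reverse: bytes -> uint8 tensor -> original dtype and shape.
restored = torch.from_numpy(raw.copy()).view(torch.bfloat16).reshape(t.shape)
assert torch.equal(restored, t)
```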

def _create_from_handle(self, handle: SharedTensorHandle):
"""Initialize from a handle"""
self._shm_name = handle.shm_name
self._shape = handle.shape
self._dtype_str = handle.dtype
self._dtype = self._parse_dtype(self._dtype_str)

# Attach to existing shared memory
self._shm = shared_memory.SharedMemory(name=self._shm_name)

def _create_tensor_view(self):
"""Create tensor view of shared memory."""
element_size = torch.tensor([], dtype=self._dtype).element_size()
total_elements = int(np.prod(self._shape))
byte_size = total_elements * element_size

# Create numpy array that shares the buffer
np_array = np.ndarray(shape=(byte_size,), dtype=np.uint8, buffer=self._shm.buf)
# Create torch tensor from numpy (shares memory)
uint8_tensor = torch.from_numpy(np_array)
tensor = uint8_tensor.view(self._dtype).reshape(self._shape)

# Keep both the np array and the SharedTensor object alive
tensor._forge_shared_tensor = self
tensor._forge_np_array = np_array

return tensor

def _parse_dtype(self, dtype_str):
"""Parse dtype string"""
dtype_str = dtype_str.replace("torch.", "")
return getattr(torch, dtype_str)

def get_handle(self):
"""Get picklable handle"""
return SharedTensorHandle(
shm_name=self._shm_name,
shape=self._shape,
dtype=self._dtype_str,
)
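
Because the handle carries only the segment name and metadata, it pickles cleanly across a process boundary where the tensor view itself would not. A hedged sketch using `multiprocessing` with the spawn start method; the child only reads and closes its mapping, while the creator remains responsible for `drop()`:

```python
import multiprocessing as mp

import torch
from forge.util._shared_tensor import SharedTensor, SharedTensorHandle

def reader(handle: SharedTensorHandle) -> None:
    shared = handle.to_shared_tensor()            # attach to the existing segment
    print("sum in child:", shared.tensor.sum().item())
    # The child never unlinks; cleanup is the creator's job.

if __name__ == "__main__":
    owner = SharedTensor(tensor=torch.ones(8, dtype=torch.float32))
    proc = mp.get_context("spawn").Process(target=reader, args=(owner.get_handle(),))
    proc.start()
    proc.join()
    owner.drop()                                  # creator unlinks the segment
```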

@functools.cached_property
def tensor(self):
"""Get the underlying tensor"""
return self._create_tensor_view()

def copy_from(self, source_tensor):
"""
Copy data from another tensor into this shared tensor
Useful when you create an empty tensor first, then fill it in place

Args:
source_tensor: Source tensor to copy from
"""
if source_tensor.shape != self._shape:
raise ValueError(f"Shape mismatch: {source_tensor.shape} vs {self._shape}")
# Copy data
self.tensor.copy_(source_tensor)

def clone(self):
"""Create a new SharedTensor with copied data"""
new_shared = SharedTensor.empty(self._shape, self._dtype)
new_shared.tensor.copy_(self.tensor)
return new_shared

def drop(self):
"""
Release and unlink the shared memory.

This method closes the shared memory handle and removes the shared memory
segment from the system. After calling this method, the shared memory
will no longer be accessible by any process.

Note:
This should be called when the shared tensor is no longer needed.
Failing to call this method may result in shared memory leaks.
"""
try:
self._shm.close()
self._shm.unlink()
except Exception:
pass
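
In the generator path above, ownership is one-directional: the controller creates the segments in `_fetch_weights`, the workers only attach and read, and the controller drops the handles after `update_weights` returns. A sketch of that pattern with a `try/finally` so segments are not leaked if loading fails (`publish_and_release` and `load_fn` are hypothetical):

```python
import torch
from forge.util._shared_tensor import SharedTensor

def publish_and_release(state_dict: dict[str, torch.Tensor], load_fn) -> None:
    """Stage params in shared memory, let load_fn consume them, then always unlink."""
    handles = {
        name: SharedTensor(tensor=t).get_handle() for name, t in state_dict.items()
    }
    try:
        load_fn(handles)  # e.g. ship handles to workers and wait for them to load
    finally:
        for handle in handles.values():
            handle.to_shared_tensor().drop()  # unlink even if loading raised
```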

def __del__(self):
"""Cleanup on deletion"""
if hasattr(self, "_shm"):
try:
self._shm.close()
except Exception:
pass

def __repr__(self):
return f"SharedTensor(shape={self._shape}, dtype={self._dtype}, shm_name={self._shm_name})"