34 commits
10d336a
shared tensor util
casteryh Oct 14, 2025
68395f9
refactor, fix test
casteryh Oct 14, 2025
3789d49
set up titan distributed through monarch utils, colocate policy with …
allenwang28 Oct 14, 2025
23c8fcb
add SharedTensorHandle class
casteryh Oct 14, 2025
f7ed526
merge
allenwang28 Oct 14, 2025
15d2784
shared memory weight loading
casteryh Oct 14, 2025
c7ca738
oopsie
casteryh Oct 14, 2025
cbd529e
end -> stop
casteryh Oct 14, 2025
20e162f
typo
casteryh Oct 14, 2025
eadf3a5
make policy_version optional
casteryh Oct 14, 2025
2b675f8
fix
casteryh Oct 14, 2025
8b488b7
no leak
casteryh Oct 14, 2025
5e7528c
disable dcp in 8b
casteryh Oct 14, 2025
2b23e50
undo the colocation
allenwang28 Oct 14, 2025
4b850ad
debug info
casteryh Oct 14, 2025
f7cbcb4
typo
casteryh Oct 14, 2025
09836f3
refactor
casteryh Oct 14, 2025
e5f984e
Merge branch 'main' into yhu/shared-tensor
casteryh Oct 14, 2025
9fa7395
temp: reduce num_replicas to 2
casteryh Oct 14, 2025
7744f4c
fix bad merge
casteryh Oct 14, 2025
8970ff4
revert to 4
casteryh Oct 14, 2025
48821de
move _fetch_weights to policy worker
casteryh Oct 15, 2025
2798f2e
typo
casteryh Oct 15, 2025
ddf8d26
endpoint
casteryh Oct 15, 2025
71b89c1
clean up
casteryh Oct 15, 2025
1971a4f
fix
casteryh Oct 15, 2025
dc301aa
fix
casteryh Oct 15, 2025
c879753
rearrange
casteryh Oct 15, 2025
c462911
log
casteryh Oct 15, 2025
571750f
vllm colocation works
allenwang28 Oct 15, 2025
a155a4c
caching allocation for weight updates
casteryh Oct 15, 2025
c2ab4e1
Merge branch 'titan_setup' into yhu/shared-tensor
casteryh Oct 15, 2025
9fa3297
test
casteryh Oct 15, 2025
0e8e183
wqMerge branch 'main' into yhu/shared-tensor
casteryh Oct 15, 2025
159 changes: 136 additions & 23 deletions src/forge/actors/generator.py
@@ -9,10 +9,12 @@
import asyncio
import logging
import os
import queue
import sys
from collections.abc import Mapping
from copy import copy
from dataclasses import dataclass, field
from typing import Optional

import torch
import torchstore as ts
@@ -49,14 +51,20 @@
load_tensor_from_dcp,
)

from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh
from forge.controller import (
ForgeActor,
get_proc_mesh,
host_mesh_from_proc,
stop_proc_mesh,
)
from forge.data_models.completion import Completion
from forge.data_models.prompt import to_prompt
from forge.env import TORCHSTORE_USE_RDMA
from forge.interfaces import Policy as GeneratorInterface
from forge.observability.metrics import record_metric, Reduce
from forge.observability.perf_tracker import Tracer
from forge.types import ProcessConfig
from forge.util._shared_tensor import SharedTensor, SharedTensorHandle

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
@@ -143,17 +151,20 @@ async def launch( # pyright: ignore[reportIncompatibleMethodOverride]
)
worker_procs = await get_proc_mesh(process_config=process_config)

# TODO - issues/144 we will want to ensure colocation with workers
# We're currently locating the Generator on the local host proc mesh
# vLLM initialization without setting env variables at proc_mesh creation
# level leads to issues.
# Once we can create multiple proc meshes on a host mesh, we can ensure
# host colocation
# Grab a single host from the workers...
host_mesh = await host_mesh_from_proc(worker_procs)
singleton_slice = {k: slice(0, 1) for k in host_mesh.extent.keys()}
host_mesh = host_mesh.slice(**singleton_slice)
generator_proc_config = copy(process_config)
generator_proc_config.procs = 1
generator_proc_config.hosts = None
generator_proc_config.with_gpus = False
generator_proc = await get_proc_mesh(process_config=generator_proc_config)

# By passing in the host_mesh here, we will get a new proc
# spawned on the provided host_mesh. Since that host mesh is
# taken from the worker_procs, this ensures colocation.
generator_proc = await get_proc_mesh(
process_config=generator_proc_config, host_mesh=host_mesh
)

if isinstance(engine_args, Mapping):
engine_args = EngineArgs(**engine_args)
@@ -204,6 +215,11 @@ async def setup(self):
self.request_lock = asyncio.Condition() # Guard for accepting_requests
self.update_lock = asyncio.Condition() # Guard for updating requests

# Shared memory allocated for weight updates
self.cached_state_dict_allocs: queue.Queue[
dict[str, SharedTensorHandle]
] = queue.Queue(maxsize=2)

vllm_config: VllmConfig = self.engine_args.create_engine_config(
UsageContext.LLM_CLASS
)
@@ -244,6 +260,59 @@ def _start_processing(self):
if self._run_task is None or self._run_task.done():
self._run_task = asyncio.create_task(self.run())

async def _drop_shared_memory(self, state_dict: dict[str, SharedTensorHandle]):
for handle in state_dict.values():
handle.drop()

async def _cleanup_shared_memory(self):
"""Cleanup shared memory allocated for weight updates."""
while not self.cached_state_dict_allocs.empty():
try:
state_dict = self.cached_state_dict_allocs.get_nowait()
await self._drop_shared_memory(state_dict)
except queue.Empty:
logger.info(
"Cached state dict alloc queue is empty. No state dict to drop."
)

async def _fetch_weights(
self,
version: int,
*,
pre_allocated: Optional[dict[str, SharedTensorHandle]] = None,
) -> dict[str, SharedTensorHandle]:
"""Fetch weights from torchstore and return a dict of {name: SharedTensorHandle}."""
t = Tracer("generator_perf/_fetch_weights")
t.start()
prefix = get_param_prefix(version)
matching_keys = await ts.keys(prefix)
hf_param_names = [extract_param_name(key) for key in matching_keys]
# We can't pass a generator since vllm load_weights is not async.
# Instead, we just call load_weights with one parameter at a time.
shared_memory_state_dict = {}
if pre_allocated is not None:
logger.info(
"[Generator] fetching weights from torchstore to shared memory. Using pre allocated shared memory."
)
shared_memory_state_dict = pre_allocated
assert set(shared_memory_state_dict.keys()) == set(
hf_param_names
), "The pre_allocated dict must have the same keys as hf_param_names"
for name, handle in shared_memory_state_dict.items():
param_key = get_param_key(version, name)
param = handle.to_shared_tensor().tensor
await ts.get(param_key, inplace_tensor=param)
else:
logger.info(
"[Generator] fetching weights from torchstore to shared memory."
)
for name in hf_param_names:
param_key = get_param_key(version, name)
param = await ts.get(param_key)
shared_memory_state_dict[name] = SharedTensor(tensor=param).get_handle()
t.stop()
return shared_memory_state_dict

@endpoint
async def generate(self, prompt: str, *, priority: int = 0) -> list[Completion]:
"""Generate a response for the given prompt
@@ -400,6 +469,14 @@ async def update_weights(self, version: int) -> None:
>>> await trainer.push_weights()
>>> generator.update_weights(version)
"""
logger.info(f"[Generator] Fetching weights for v{version} to shared memory")
try:
pre_allocated = self.cached_state_dict_allocs.get_nowait()
except queue.Empty:
pre_allocated = None
fetch_task = asyncio.create_task(
self._fetch_weights(version, pre_allocated=pre_allocated)
)
# Serialize updates (only one update at a time)
async with self.update_lock:
# Grab the lock to stop accepting requests and wait on pending requests
@@ -431,8 +508,27 @@ async def update_weights(self, version: int) -> None:
)

logger.debug(f"Starting weight update on {self.__class__.__name__}")
# Call update_weights on every generator_worker
await self.generator_worker.update_weights.call(version=version)
if not self.use_dcp:
# TODO: the allocation inside ts.get currently blocks the event loop;
# we may need to change torchstore to avoid this.
# We have to await the fetch task here because Monarch futures are not directly compatible with asyncio.
t = Tracer("generator_perf/waiting_for_fetch_weights")
t.start()
fetched_weights = await fetch_task
t.stop()
# Call update_weights on every generator_worker
await self.generator_worker.update_weights.call(
shared_memory_state_dict=fetched_weights
)
try:
self.cached_state_dict_allocs.put_nowait(fetched_weights)
except queue.Full:
logger.info(
"Cached state dict alloc queue is full. Dropping allocated state dict."
)
await self._drop_shared_memory(fetched_weights)
else:
await self.generator_worker.update_weights.call(version=version)
self.generator_version = version

# After updating the weights, we need to reset the KV cache
@@ -504,6 +600,7 @@ async def shutdown( # pyright: ignore[reportIncompatibleMethodOverride]
# TODO - may want to expand stop to gracefully respond to
# ongoing requests.
await actor.stop.call()
await actor._cleanup_shared_memory.call()
await stop_proc_mesh(actor._worker_procs)
await stop_proc_mesh(actor._generator_proc)

@@ -597,13 +694,37 @@ async def execute_model(self, schedule: SchedulerOutput) -> ModelRunnerOutput:
return self.worker.execute_model(schedule)

@endpoint
async def update_weights(self, version: int) -> None:
async def update_weights(
self,
version: Optional[int] = None,
*,
shared_memory_state_dict: Optional[dict[str, SharedTensorHandle]] = None,
) -> None:
model = self.worker.model_runner.model
if shared_memory_state_dict is not None:
logger.info("[PolicyWorker] update weights from shared memory.")
t = Tracer(
"generator_worker_perf/update_weights_from_shared_memory", timer="gpu"
)
t.start()
loaded_weights = set()
for name, param_handle in shared_memory_state_dict.items():
param = param_handle.to_shared_tensor().tensor
loaded = model.load_weights([(name, param)])
loaded_weights.update(loaded)
logger.info(f"[PolicyWorker] updated {len(loaded_weights)} parameters")
t.stop()
return
if version is None:
raise ValueError(
"version must be provided if not using shared_memory_state_dict"
)
# If shared memory is not provided, we assume we are using DCP
prefix = get_param_prefix(version)
matching_keys = await ts.keys(prefix)
dcp_whole_state_dict_key = get_dcp_whole_state_dict_key(version)
loaded_weights = set()
t = Tracer("worker_perf/update_weights", timer="gpu")
t = Tracer("generator_worker_perf/update_weights", timer="gpu")
t.start()
# Entire state dict is stored in a single DCP handle
if dcp_whole_state_dict_key in matching_keys:
@@ -614,16 +735,8 @@ async def update_weights(self, version: int) -> None:
loaded = model.load_weights([(name, param)])
del param
loaded_weights.update(loaded)
else: # Load each parameter from torchstore directly without DCP
hf_param_names = [extract_param_name(key) for key in matching_keys]
# We can't pass a generator since vllm load_weights is not async.
# Instead, we just call load_weights with one parameter at a time.
for name in hf_param_names:
param_key = get_param_key(version, name)
param = await ts.get(param_key)
loaded = model.load_weights([(name, param)])
del param
loaded_weights.update(loaded)
else:
raise RuntimeError("No DCP handle found for the given version")
t.stop()

@endpoint
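The generator changes above move model weights between actors as `SharedTensorHandle`s from `forge.util._shared_tensor`: `_fetch_weights` copies each parameter from torchstore into shared memory and keeps a handle (`SharedTensor(tensor=param).get_handle()`), the worker rebuilds a tensor from each handle (`handle.to_shared_tensor().tensor`) and feeds it to `model.load_weights`, and `handle.drop()` frees the block. The standalone sketch below illustrates only that general round trip with Python's `multiprocessing.shared_memory`; it is not the forge implementation, and the `ToySharedTensorHandle` / `put_in_shared_memory` names and the copy-on-read in `to_tensor()` are assumptions made for the sketch.

```python
# Minimal sketch of a shared-memory tensor handle round trip. This is NOT
# forge.util._shared_tensor; names and behavior here are illustrative only.
from dataclasses import dataclass
from multiprocessing import shared_memory

import numpy as np
import torch


@dataclass
class ToySharedTensorHandle:
    """Small picklable description of a tensor stored in a shared memory block."""

    shm_name: str
    shape: tuple[int, ...]
    dtype: str  # numpy dtype name, e.g. "float32"

    def to_tensor(self) -> torch.Tensor:
        # Attach to the block by name and copy the data out. The real handle
        # presumably returns a zero-copy view; copying keeps the sketch simple.
        shm = shared_memory.SharedMemory(name=self.shm_name)
        arr = np.ndarray(self.shape, dtype=np.dtype(self.dtype), buffer=shm.buf)
        out = torch.from_numpy(arr.copy())
        del arr  # release the buffer export before closing the mapping
        shm.close()
        return out

    def drop(self) -> None:
        # Detach and remove the underlying shared memory block.
        shm = shared_memory.SharedMemory(name=self.shm_name)
        shm.close()
        shm.unlink()


def put_in_shared_memory(tensor: torch.Tensor) -> ToySharedTensorHandle:
    """Producer side: copy a tensor into a fresh shared memory block."""
    arr = tensor.detach().cpu().numpy()
    shm = shared_memory.SharedMemory(create=True, size=arr.nbytes)
    dst = np.ndarray(arr.shape, dtype=arr.dtype, buffer=shm.buf)
    dst[:] = arr
    handle = ToySharedTensorHandle(shm.name, tuple(arr.shape), arr.dtype.name)
    del dst
    shm.close()  # the block itself survives until unlink() is called
    return handle


if __name__ == "__main__":
    handle = put_in_shared_memory(torch.randn(4, 4))  # analogous to _fetch_weights
    restored = handle.to_tensor()                     # analogous to the worker's load
    print(restored.shape)
    handle.drop()                                     # analogous to _drop_shared_memory
```

The `cached_state_dict_allocs` queue (`maxsize=2`) then acts as a small pool of such allocations: after an update, the fetched handles are put back and reused by the next `_fetch_weights` call, so consecutive weight updates avoid re-allocating shared memory while keeping a bounded number of cached state-dict copies resident.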
15 changes: 1 addition & 14 deletions src/forge/actors/trainer.py
@@ -5,7 +5,6 @@
# LICENSE file in the root directory of this source tree.

import logging
import math
import os
import shutil

@@ -18,7 +17,7 @@
import torch.distributed.checkpoint as dcp
import torchstore as ts

from monarch.actor import current_rank, current_size, endpoint
from monarch.actor import endpoint
from torch import Tensor
from torch.distributed.checkpoint._nested_dict import flatten_state_dict
from torchtitan.config.job_config import (
@@ -163,19 +162,7 @@ def __post_init__(self):
self.step = 1 # fragile contract.
self.num_training_steps = self.training.steps
self.gradient_accumulation_steps = 1
self.rank = current_rank().rank
self.size = math.prod(current_size().values())

env = {
"RANK": str(self.rank),
"LOCAL_RANK": str(self.rank),
"LOCAL_WORLD_SIZE": str(self.size),
"GROUP_RANK": str(self.size),
"GROUP_WORLD_SIZE": str(self.size),
"ROLE_RANK": str(self.rank),
"ROLE_WORLD_SIZE": str(self.size),
"ROLE_NAME": "rank",
"WORLD_SIZE": str(self.size),
"PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
}
os.environ.update(env)
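The trainer no longer exports its own rank and world-size variables (the block removed above); instead the provisioner populates the standard torch.distributed environment once per spawned process via `monarch.utils.setup_env_for_distributed` (see the `provisioner.py` change below). As a rough, self-contained illustration of the contract being relied on (not forge or monarch code), the sketch below initializes `torch.distributed` purely from those environment variables; the single-rank defaults in `__main__` are assumptions so it runs standalone.

```python
# Minimal sketch, assuming the launcher has already exported the standard
# torch.distributed variables for each process. Not forge or monarch code.
import os

import torch.distributed as dist


def init_from_env() -> None:
    """Initialize torch.distributed purely from environment variables."""
    required = {"MASTER_ADDR", "MASTER_PORT", "RANK", "WORLD_SIZE"}
    missing = required - set(os.environ)
    if missing:
        raise RuntimeError(f"distributed env not set up; missing: {sorted(missing)}")
    # init_method="env://" reads MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE.
    dist.init_process_group(backend="gloo", init_method="env://")


if __name__ == "__main__":
    # Single-rank defaults so the sketch runs standalone.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    os.environ.setdefault("RANK", "0")
    os.environ.setdefault("WORLD_SIZE", "1")
    init_from_env()
    print(f"rank {dist.get_rank()} of {dist.get_world_size()}")
    dist.destroy_process_group()
```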
10 changes: 9 additions & 1 deletion src/forge/controller/provisioner.py
@@ -20,8 +20,9 @@

from monarch.tools import commands

from forge.controller.launcher import BaseLauncher, get_launcher
from monarch.utils import setup_env_for_distributed

from forge.controller.launcher import BaseLauncher, get_launcher
from forge.env import all_env_vars, FORGE_DISABLE_METRICS
from forge.types import ProcessConfig, ProvisionerConfig

@@ -283,6 +284,13 @@ def bootstrap(env: dict[str, str]):
bootstrap=functools.partial(bootstrap, env=env_vars),
)

# Set up environment variables for PyTorch distributed...
await setup_env_for_distributed(
procs,
master_addr=addr,
master_port=port,
)

if is_remote:
await self.launcher.remote_setup(procs)
