Merged

79 commits
10d336a
shared tensor util
casteryh Oct 14, 2025
68395f9
refactor, fix test
casteryh Oct 14, 2025
3789d49
set up titan distributed through monarch utils, colocate policy with …
allenwang28 Oct 14, 2025
23c8fcb
add SharedTensorHandle class
casteryh Oct 14, 2025
f7ed526
merge
allenwang28 Oct 14, 2025
15d2784
shared memory weight loading
casteryh Oct 14, 2025
c7ca738
oopsie
casteryh Oct 14, 2025
cbd529e
end -> stop
casteryh Oct 14, 2025
20e162f
typo
casteryh Oct 14, 2025
eadf3a5
make policy_version optional
casteryh Oct 14, 2025
2b675f8
fix
casteryh Oct 14, 2025
8b488b7
no leak
casteryh Oct 14, 2025
5e7528c
disable dcp in 8b
casteryh Oct 14, 2025
2b23e50
undo the colocation
allenwang28 Oct 14, 2025
4b850ad
debug info
casteryh Oct 14, 2025
f7cbcb4
typo
casteryh Oct 14, 2025
09836f3
refactor
casteryh Oct 14, 2025
e5f984e
Merge branch 'main' into yhu/shared-tensor
casteryh Oct 14, 2025
9fa7395
temp: reduce num_replicas to 2
casteryh Oct 14, 2025
7744f4c
fix bad merge
casteryh Oct 14, 2025
8970ff4
revert to 4
casteryh Oct 14, 2025
48821de
move _fetch_weights to policy worker
casteryh Oct 15, 2025
2798f2e
typo
casteryh Oct 15, 2025
ddf8d26
endpoint
casteryh Oct 15, 2025
71b89c1
clean up
casteryh Oct 15, 2025
1971a4f
fix
casteryh Oct 15, 2025
dc301aa
fix
casteryh Oct 15, 2025
c879753
rearrange
casteryh Oct 15, 2025
c462911
log
casteryh Oct 15, 2025
571750f
vllm colocation works
allenwang28 Oct 15, 2025
a155a4c
caching allocation for weight updates
casteryh Oct 15, 2025
c2ab4e1
Merge branch 'titan_setup' into yhu/shared-tensor
casteryh Oct 15, 2025
9fa3297
test
casteryh Oct 15, 2025
68be7a7
multi processing for fetch
casteryh Oct 15, 2025
f30e666
endpoint
casteryh Oct 15, 2025
a190623
cleanup
casteryh Oct 15, 2025
af3f35c
fix
casteryh Oct 15, 2025
8fb6dde
fix
casteryh Oct 15, 2025
35c0052
fix queue pair checked out
casteryh Oct 15, 2025
555966f
typo
casteryh Oct 15, 2025
16b060b
spawn on worker procs
casteryh Oct 15, 2025
1e4110c
add back generator_proc
casteryh Oct 15, 2025
efe91e3
fix slice
casteryh Oct 15, 2025
d738f5e
fix fetcher call
casteryh Oct 15, 2025
957b4cd
fix arguments
casteryh Oct 15, 2025
be85a94
remove await
casteryh Oct 15, 2025
002d68b
make handle droppable
casteryh Oct 15, 2025
090b6ec
spawn in setup
casteryh Oct 15, 2025
1f9dc13
fix
casteryh Oct 15, 2025
289ca08
add t.stop()
casteryh Oct 15, 2025
e854ddd
create tasks
casteryh Oct 15, 2025
bee85df
pray
casteryh Oct 15, 2025
de3b41d
Revert "pray"
casteryh Oct 15, 2025
2b62964
Revert "create tasks"
casteryh Oct 15, 2025
aa57ba3
more procs
casteryh Oct 15, 2025
fab7cd4
procs=1
casteryh Oct 15, 2025
2d14c40
cleanup, guard with flag
casteryh Oct 15, 2025
acf72b8
Merge branch 'main' into yhu/shared-tensor-mp
casteryh Oct 15, 2025
d62b44f
clean up unused
casteryh Oct 15, 2025
6cf22d7
fork
casteryh Oct 15, 2025
bd921ca
Merge branch 'main' into yhu/shared-tensor-mp
casteryh Oct 15, 2025
7d56eef
remove cleanup which no longer exists
casteryh Oct 15, 2025
fc37e27
fix bad merge
casteryh Oct 15, 2025
bda4798
multiprocessing -> multiprocess
casteryh Oct 16, 2025
ab7dae5
separate spawn_fetchers method
casteryh Oct 16, 2025
e2785f7
rename flag
casteryh Oct 17, 2025
125b039
rename flag
casteryh Oct 17, 2025
b36eb9a
Merge branch 'main' into yhu/shared-tensor-mp
casteryh Oct 17, 2025
c9df35f
tp=8
casteryh Oct 17, 2025
dfd8656
Fix SharedTensor memory leaks with explicit lifecycle management
casteryh Oct 17, 2025
101b7a8
Use context manager pattern for SharedTensor in generator
casteryh Oct 17, 2025
ac9240c
Enable shared memory weight prefetching by default
casteryh Oct 17, 2025
9c35c31
Remove experimental config file
casteryh Oct 17, 2025
4c981bd
Merge remote-tracking branch 'origin/main' into yhu/shared-tensor-mp
casteryh Oct 17, 2025
a556064
Update src/forge/actors/generator.py
casteryh Oct 17, 2025
777a411
debug
casteryh Oct 17, 2025
db1394e
Add missing del param in _WeightFetcher.fetch()
casteryh Oct 17, 2025
856c070
Improve error handling in SharedTensor.close() and drop()
casteryh Oct 17, 2025
b3732de
Fix SharedTensor anti-pattern in multiprocess tests
casteryh Oct 17, 2025
1 change: 1 addition & 0 deletions pyproject.toml
@@ -42,6 +42,7 @@ dev = [
"tomli>=1.1.0",
"anyio",
"pytest-asyncio",
"multiprocess",
]
oss = [
"torch",
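
The only functional change above is the new dev dependency on multiprocess, the dill-based fork of the standard library's multiprocessing (see the commit "multiprocessing -> multiprocess" in the history above). A minimal sketch of the drop-in relationship; the worker function below is hypothetical and not part of this PR:

# Sketch only: `multiprocess` mirrors the stdlib `multiprocessing` API but
# serializes with dill, so locally defined callables can be shipped to
# child processes.
import multiprocess as mp  # drop-in for `import multiprocessing as mp`

def _square(x):  # hypothetical worker function
    return x * x

if __name__ == "__main__":
    with mp.Pool(processes=2) as pool:
        print(pool.map(_square, [1, 2, 3]))  # -> [1, 4, 9]
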
133 changes: 128 additions & 5 deletions src/forge/actors/generator.py
@@ -13,10 +13,12 @@
from collections.abc import Mapping
from copy import copy
from dataclasses import dataclass, field
from typing import Optional

import torch
import torchstore as ts
from monarch.actor import current_rank, endpoint, ProcMesh
from monarch.actor import current_rank, endpoint, ProcMesh, this_host

from vllm.config import VllmConfig

from vllm.engine.arg_utils import EngineArgs
@@ -60,6 +62,7 @@
from forge.observability.metrics import record_metric, Reduce
from forge.observability.perf_tracker import Tracer
from forge.types import ProcessConfig
from forge.util._shared_tensor import SharedTensor, SharedTensorHandle

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
@@ -92,6 +95,8 @@ class Generator(ForgeActor):
engine_args: EngineArgs | Mapping = field(default_factory=EngineArgs)
sampling_params: SamplingParams | Mapping = field(default_factory=SamplingParams)
use_dcp_for_weight_sync: bool | None = None
prefetch_weights_to_shm: bool = True

Review comment (Contributor): In general we should try to avoid changing the "public" api when we expect to quickly change the backend again. After launch we should try to keep this in mind.

Reply (Contributor Author, casteryh): agreed.

n_fetcher_procs: int = 8

def __post_init__(self):
super().__init__()
@@ -226,11 +231,61 @@ async def setup(self):
log_stats=None,
)
self._start_processing()
if self.prefetch_weights_to_shm:
self._spawn_fetchers()

def _spawn_fetchers(self):
"""Spawn weight fetchers that prefetch weights from torchstore to shared memory."""
# TODO: this assumes the generator is on the same host as the worker
# and only works for single host generators. Figure out how to support
# generators with workers spanned across multiple hosts.
fetcher_procs = this_host().spawn_procs(

Review comment (Contributor): can we guard this with if prefetch_weights? It may also be a bit cleaner to put this init in another function w/ some documentation for what the fetchers here do

Reply (Contributor Author, casteryh, Oct 16, 2025): will move to a separate function.

per_host={"procs": self.n_fetcher_procs}

Review comment (Contributor): not that we need to address now, but I think we need to spawn these fetcher procs across all generator nodes, right?

Reply (Contributor Author, casteryh): yes, but I assume the setup call is broadcast and every Generator node will spawn its own fetcher_procs?

Reply (Contributor): that is correct, I meant the case where a generator's workers span two nodes, e.g. DeepSeek. In that case we would probably want to spin up the fetchers on the worker nodes, right?

Reply (Contributor Author, casteryh): Correct.

Reply (Contributor): I am a bit confused -- shouldn't a vLLM worker be scoped to a single node?

Reply (Contributor): I also don't follow why you need more than one fetcher. Is it to allow you to parallelize torchstore requests?

)
self._fetcher_procs = fetcher_procs
self.weight_fetchers = fetcher_procs.spawn("weight_fetcher", _WeightFetcher)
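# Several fetcher processes are used so that torchstore reads and the copies
# into shared memory can run in parallel instead of serially in one process
# (see the review discussion above).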

def _start_processing(self):
if self._run_task is None or self._run_task.done():
self._run_task = asyncio.create_task(self.run())

async def _drop_shared_memory(self, state_dict: dict[str, SharedTensorHandle]):
for handle in state_dict.values():
handle.drop()

async def _fetch_weights(
self,
version: int,
) -> dict[str, SharedTensorHandle]:
"""Fetch weights from torchstore and return a dict of {name: SharedTensorHandle}."""
t = Tracer("generator_perf/_fetch_weights")
t.start()
prefix = get_param_prefix(version)
matching_keys = await ts.keys(prefix)
hf_param_names = [extract_param_name(key) for key in matching_keys]

n_fetchers = self.weight_fetchers.size()

def split_keys(keys):
return [keys[i::n_fetchers] for i in range(n_fetchers)]
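# Round-robin split; e.g. with n_fetchers == 2,
#   ["w0", "w1", "w2", "w3"] -> [["w0", "w2"], ["w1", "w3"]]
# so each fetcher pulls a disjoint subset of the parameters.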

futures = []
for i, names in enumerate(split_keys(hf_param_names)):
fut = self.weight_fetchers.slice(procs=i).fetch.call_one(
version=version, param_names=names
)
futures.append(fut)

sub_state_dicts = [await fut for fut in futures]

state_dict = {}
for sd in sub_state_dicts:
state_dict.update(sd)

t.stop()

return state_dict

@endpoint
async def generate(self, prompt: str, *, priority: int = 0) -> list[Completion]:
"""Generate a response for the given prompt
@@ -384,6 +439,12 @@ async def update_weights(self, version: int) -> None:
>>> await trainer.push_weights()
>>> generator.update_weights(version)
"""
# Prefetch only if we are using RDMA
if self.prefetch_weights_to_shm and not self.use_dcp_for_weight_sync:
logger.info(f"[Generator] Fetching weights for v{version} to shared memory")
fetch_fut = asyncio.create_task(self._fetch_weights(version))
else:
fetch_fut = None
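# The prefetch task created above runs while the code below acquires the
# update lock and drains in-flight requests, so the torchstore reads overlap
# with the pause.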
# Serialize updates (only one update at a time)
async with self.update_lock:
# Grab the lock to stop accepting requests and wait on pending requests
@@ -415,8 +476,19 @@ async def update_weights(self, version: int) -> None:
)

logger.debug(f"Starting weight update on {self.__class__.__name__}")
# Call update_weights on every generator worker
await self.worker.update_weights.call(version=version)

if fetch_fut is not None:
t = Tracer("generator_perf/waiting_for_fetch_weights")
t.start()
fetched_weights = await fetch_fut
t.stop()
# Call update_weights on every policy_worker
await self.worker.update_weights.call(
shared_memory_state_dict=fetched_weights
)
await self._drop_shared_memory(fetched_weights)
else:
await self.worker.update_weights.call(version=version)
self.generator_version = version

# After updating the weights, we need to reset the KV cache
@@ -490,6 +562,7 @@ async def shutdown( # pyright: ignore[reportIncompatibleMethodOverride]
await actor.stop.call()
await stop_proc_mesh(actor._worker_procs)
await stop_proc_mesh(actor._generator_proc)
await stop_proc_mesh(actor._fetcher_procs)

@endpoint
async def save_model_params(self):
@@ -569,14 +642,41 @@ async def execute_model(self, schedule: SchedulerOutput) -> ModelRunnerOutput:
return self.worker.execute_model(schedule)

@endpoint
async def update_weights(self, version: int) -> None:
async def update_weights(
self,
version: Optional[int] = None,
*,
shared_memory_state_dict: Optional[dict[str, SharedTensorHandle]] = None,
) -> None:
model = self.worker.model_runner.model
if shared_memory_state_dict is not None:
logger.info("[PolicyWorker] update weights from shared memory.")
t = Tracer(
"generator_worker_perf/update_weights_from_shared_memory", timer="gpu"
)
t.start()
loaded_weights = set()
for name, param_handle in shared_memory_state_dict.items():
# Use context manager for automatic cleanup
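# (assumed semantics: exiting the manager releases this worker's mapping of
# the segment; the generator unlinks the segments afterwards via
# _drop_shared_memory)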
with param_handle.to_shared_tensor() as shared_tensor:
param = shared_tensor.tensor
loaded = model.load_weights([(name, param)])
loaded_weights.update(loaded)
logger.info(f"[PolicyWorker] updated {len(loaded_weights)} parameters")
t.stop()
return
# normal update_weights without shared memory prefetching
if version is None:
raise ValueError(
"version must be provided if not using shared_memory_state_dict"
)
logger.info("[PolicyWorker] update weights from torchstore.")
prefix = get_param_prefix(version)
matching_keys = await ts.keys(prefix)
dcp_whole_state_dict_key = get_dcp_whole_state_dict_key(version)
use_dcp_for_weight_sync = dcp_whole_state_dict_key in matching_keys
loaded_weights = set()
t = Tracer("worker_perf/update_weights", timer="gpu")
t = Tracer("generator_worker_perf/update_weights_from_torchstore", timer="gpu")
t.start()

if use_dcp_for_weight_sync:
@@ -617,3 +717,26 @@ async def validate_model_params(self, validate_fn):
return validate_fn(
self._debug_saved_params, self.worker.model_runner.model, logger
)


class _WeightFetcher(ForgeActor):

Review comment (Contributor): I think this could be a method on the generator that gets called from main, so prefetch is controlled and visible from the main loop. I am curious if this has to actually be a separate process since this is an async method and I would think most of the time it's waiting on ts.get.

Reply (Contributor Author, casteryh): This has to be a separate actor because it has to be launched in a separate process.

"""Fetches weights from torchstore and loads them into shared memory.
This has to be colocated with the GeneratorWorker."""

@endpoint
async def fetch(
self,
*,
version: int,
param_names: list[str],
) -> dict[str, SharedTensorHandle]:
"""Fetch weights from torchstore and load them into shared memory."""
sd = {}
for name in param_names:
param_key = get_param_key(version, name)
param = await ts.get(param_key)
# Use context manager to ensure cleanup after getting handle
with SharedTensor(tensor=param) as shared_tensor:

Review comment (Contributor): Is the plan to move this to TS and hide the rdma/shared memory logic from the user?

Reply (Contributor Author, casteryh): Hopefully yes.

handle = shared_tensor.get_handle()
sd[name] = handle
return sd
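
Putting the pieces together, the shared-memory update path works as follows when prefetch_weights_to_shm is enabled: the generator prefetches weights from torchstore into host shared memory via the fetcher actors, hands the resulting handles to every worker, and unlinks the segments once the weights are loaded. Below is a condensed, non-runnable sketch of that control flow, assuming the actor meshes from this diff (weight_fetchers, worker) are already spawned; it omits the request pausing, KV-cache reset, and DCP fallback shown in the real update_weights above.

import asyncio

async def update_weights_via_shm(generator, version: int) -> None:
    # 1. Start the prefetch; it runs while in-flight generations drain.
    fetch_task = asyncio.create_task(generator._fetch_weights(version))

    async with generator.update_lock:
        # 2. Wait for the fetchers to finish copying every parameter shard
        #    from torchstore into shared memory.
        handles = await fetch_task  # dict[str, SharedTensorHandle]

        # 3. Each GeneratorWorker maps the handles and loads the weights
        #    into the vLLM model.
        await generator.worker.update_weights.call(shared_memory_state_dict=handles)

        # 4. Unlink the shared-memory segments now that the GPUs own the weights.
        await generator._drop_shared_memory(handles)

    generator.generator_version = version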