shared memory multiprocess prefetch for weight update #430
Changes from 74 commits
First file in the diff:

@@ -42,6 +42,7 @@ dev = [
    "tomli>=1.1.0",
    "anyio",
    "pytest-asyncio",
+   "multiprocess",
]
oss = [
    "torch",
Second file in the diff:

@@ -13,10 +13,12 @@
from collections.abc import Mapping
from copy import copy
from dataclasses import dataclass, field
+from typing import Optional

import torch
import torchstore as ts
-from monarch.actor import current_rank, endpoint, ProcMesh
+from monarch.actor import current_rank, endpoint, ProcMesh, this_host

from vllm.config import VllmConfig

from vllm.engine.arg_utils import EngineArgs

@@ -60,6 +62,7 @@
from forge.observability.metrics import record_metric, Reduce
from forge.observability.perf_tracker import Tracer
from forge.types import ProcessConfig
+from forge.util._shared_tensor import SharedTensor, SharedTensorHandle

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
@@ -92,6 +95,8 @@ class Generator(ForgeActor):
    engine_args: EngineArgs | Mapping = field(default_factory=EngineArgs)
    sampling_params: SamplingParams | Mapping = field(default_factory=SamplingParams)
    use_dcp_for_weight_sync: bool | None = None
+   prefetch_weights_to_shm: bool = True
+   n_fetcher_procs: int = 8

    def __post_init__(self):
        super().__init__()
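The two new knobs sit alongside the existing dataclass fields, so they can be set wherever `engine_args` and `sampling_params` are configured today. An illustrative sketch (the values and the dict-based construction are assumptions; only the field names come from the diff):

```python
# Illustrative only: values are made up; the field names match the diff above.
generator_cfg = dict(
    engine_args={"model": "Qwen/Qwen3-8B"},  # hypothetical model choice
    sampling_params={"max_tokens": 512},     # hypothetical sampling settings
    prefetch_weights_to_shm=True,            # prefetch weights into shared memory (default)
    n_fetcher_procs=8,                       # fetcher processes spawned on the generator host
)
```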
@@ -226,11 +231,61 @@ async def setup(self):
            log_stats=None,
        )
        self._start_processing()
        if self.prefetch_weights_to_shm:
            self._spawn_fetchers()

    def _spawn_fetchers(self):
        """Spawn weight fetchers that prefetch weights from torchstore to shared memory."""
        # TODO: this assumes the generator is on the same host as the worker
        # and only works for single-host generators. Figure out how to support
        # generators with workers spanning multiple hosts.
        fetcher_procs = this_host().spawn_procs(
            per_host={"procs": self.n_fetcher_procs}
        )
        self._fetcher_procs = fetcher_procs
        self.weight_fetchers = fetcher_procs.spawn("weight_fetcher", _WeightFetcher)

Review thread on `this_host().spawn_procs(...)`:
- can we guard this with
- will move to a separate function.

Review thread on `per_host={"procs": self.n_fetcher_procs}`:
- not that we need to address now, but I think we need to spawn these fetcher procs across all generator nodes, right?
- yes. but I assume
- that is correct, I meant if a generator's workers span 2 nodes, i.e. DeepSeek. In that case we would probably want to spin up the fetchers on the worker nodes, right?
- Correct
- I am a bit confused -- shouldn't the vLLM worker be scoped to a single node?
- I also don't follow why you need more than 1. Is it to allow you to parallelize torchstore requests?
    def _start_processing(self):
        if self._run_task is None or self._run_task.done():
            self._run_task = asyncio.create_task(self.run())

    async def _drop_shared_memory(self, state_dict: dict[str, SharedTensorHandle]):
        for handle in state_dict.values():
            handle.drop()

    async def _fetch_weights(
        self,
        version: int,
    ) -> dict[str, SharedTensorHandle]:
        """Fetch weights from torchstore and return a dict of {name: SharedTensorHandle}."""
        t = Tracer("generator_perf/_fetch_weights")
        t.start()
        prefix = get_param_prefix(version)
        matching_keys = await ts.keys(prefix)
        hf_param_names = [extract_param_name(key) for key in matching_keys]

        n_fetchers = self.weight_fetchers.size()

        def split_keys(keys):
            return [keys[i::n_fetchers] for i in range(n_fetchers)]

        futures = []
        for i, names in enumerate(split_keys(hf_param_names)):
            fut = self.weight_fetchers.slice(procs=i).fetch.call_one(
                version=version, param_names=names
            )
            futures.append(fut)

        sub_state_dicts = [await fut for fut in futures]

        state_dict = {}
        for sd in sub_state_dicts:
            state_dict.update(sd)

        t.stop()

        return state_dict
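The `keys[i::n_fetchers]` slicing hands out parameter names round-robin, one shard per fetcher process, and merging the per-fetcher sub-dicts reconstructs the full mapping. A standalone sketch of that splitting (the parameter names here are made up):

```python
# Round-robin split of parameter names across fetcher processes.
# The names below are illustrative, not taken from any real model.
param_names = [
    "embed.weight",
    "layers.0.attn.qkv_proj.weight",
    "layers.0.mlp.up_proj.weight",
    "layers.1.attn.qkv_proj.weight",
    "layers.1.mlp.up_proj.weight",
]
n_fetchers = 2

shards = [param_names[i::n_fetchers] for i in range(n_fetchers)]
# shards[0] -> indices 0, 2, 4
# shards[1] -> indices 1, 3

# Every name lands in exactly one shard, so merging the per-fetcher results
# rebuilds the complete state dict.
merged = [name for shard in shards for name in shard]
assert sorted(merged) == sorted(param_names)
```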
    @endpoint
    async def generate(self, prompt: str, *, priority: int = 0) -> list[Completion]:
        """Generate a response for the given prompt

@@ -384,6 +439,12 @@ async def update_weights(self, version: int) -> None:
        >>> await trainer.push_weights()
        >>> generator.update_weights(version)
        """
        # Prefetch only if we are using RDMA
        if self.prefetch_weights_to_shm and not self.use_dcp_for_weight_sync:
            logger.info(f"[Generator] Fetching weights for v{version} to shared memory")
            fetch_fut = asyncio.create_task(self._fetch_weights(version))
        else:
            fetch_fut = None
        # Serialize updates (only one update at a time)
        async with self.update_lock:
            # Grab the lock to stop accepting requests and wait on pending requests
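The prefetch is started as a background task before the update lock is taken, so pulling tensors out of torchstore overlaps with draining in-flight requests instead of lengthening the pause. A minimal sketch of that overlap pattern, with placeholder coroutines standing in for the real steps (all names below are made up):

```python
import asyncio

# Placeholders for the real prefetch and drain steps.
async def fetch_weights_to_shm(version: int) -> dict:
    await asyncio.sleep(1.0)   # stands in for many ts.get() calls
    return {"some.param": "shared-memory handle"}

async def drain_inflight_requests() -> None:
    await asyncio.sleep(0.5)   # stands in for waiting on pending generations

async def update(version: int) -> None:
    # Kick off the fetch first; it runs while requests drain.
    fetch_task = asyncio.create_task(fetch_weights_to_shm(version))
    await drain_inflight_requests()   # overlaps with the fetch above
    handles = await fetch_task        # usually already done by now
    print(f"loading {len(handles)} prefetched params for v{version}")

asyncio.run(update(1))
```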
@@ -415,8 +476,19 @@ async def update_weights(self, version: int) -> None:
            )

            logger.debug(f"Starting weight update on {self.__class__.__name__}")
-           # Call update_weights on every generator worker
-           await self.worker.update_weights.call(version=version)
+           if fetch_fut is not None:
+               t = Tracer("generator_perf/waiting_for_fetch_weights")
+               t.start()
+               fetched_weights = await fetch_fut
+               t.stop()
+               # Call update_weights on every policy_worker
+               await self.worker.update_weights.call(
+                   shared_memory_state_dict=fetched_weights
+               )
+               await self._drop_shared_memory(fetched_weights)
+           else:
+               await self.worker.update_weights.call(version=version)
            self.generator_version = version

            # After updating the weights, we need to reset the KV cache
@@ -490,6 +562,7 @@ async def shutdown(  # pyright: ignore[reportIncompatibleMethodOverride]
        await actor.stop.call()
        await stop_proc_mesh(actor._worker_procs)
        await stop_proc_mesh(actor._generator_proc)
+       await stop_proc_mesh(actor._fetcher_procs)

    @endpoint
    async def save_model_params(self):
@@ -569,14 +642,41 @@ async def execute_model(self, schedule: SchedulerOutput) -> ModelRunnerOutput:
        return self.worker.execute_model(schedule)

    @endpoint
-   async def update_weights(self, version: int) -> None:
+   async def update_weights(
+       self,
+       version: Optional[int] = None,
+       *,
+       shared_memory_state_dict: Optional[dict[str, SharedTensorHandle]] = None,
+   ) -> None:
        model = self.worker.model_runner.model
        if shared_memory_state_dict is not None:
            logger.info("[PolicyWorker] update weights from shared memory.")
            t = Tracer(
                "generator_worker_perf/update_weights_from_shared_memory", timer="gpu"
            )
            t.start()
            loaded_weights = set()
            for name, param_handle in shared_memory_state_dict.items():
                # Use context manager for automatic cleanup
                with param_handle.to_shared_tensor() as shared_tensor:
                    param = shared_tensor.tensor
                    loaded = model.load_weights([(name, param)])
                    loaded_weights.update(loaded)
            logger.info(f"[PolicyWorker] updated {len(loaded_weights)} parameters")
            t.stop()
            return
        # normal update_weights without shared memory prefetching
        if version is None:
            raise ValueError(
                "version must be provided if not using shared_memory_state_dict"
            )
        logger.info("[PolicyWorker] update weights from torchstore.")
        prefix = get_param_prefix(version)
        matching_keys = await ts.keys(prefix)
        dcp_whole_state_dict_key = get_dcp_whole_state_dict_key(version)
        use_dcp_for_weight_sync = dcp_whole_state_dict_key in matching_keys
        loaded_weights = set()
-       t = Tracer("worker_perf/update_weights", timer="gpu")
+       t = Tracer("generator_worker_perf/update_weights_from_torchstore", timer="gpu")
        t.start()

        if use_dcp_for_weight_sync:
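`model.load_weights` is called one parameter at a time here; each call returns the names it actually loaded, which is how `loaded_weights` grows into the full set. A toy stand-in for that contract (not vLLM code; class and values are made up, only the call pattern mirrors the diff):

```python
import torch

class ToyModel:
    """Stand-in with a load_weights contract like the one used above:
    takes (name, tensor) pairs and returns the set of names it loaded."""
    def __init__(self):
        self.params = {"w": torch.zeros(2, 2), "b": torch.zeros(2)}

    def load_weights(self, weights):
        loaded = set()
        for name, tensor in weights:
            if name in self.params:
                self.params[name].copy_(tensor)
                loaded.add(name)
        return loaded

model = ToyModel()
loaded_weights = set()
state_dict = {"w": torch.ones(2, 2), "b": torch.ones(2)}
for name, param in state_dict.items():
    loaded_weights.update(model.load_weights([(name, param)]))
print(f"updated {len(loaded_weights)} parameters")  # -> updated 2 parameters
```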
@@ -617,3 +717,26 @@ async def validate_model_params(self, validate_fn):
        return validate_fn(
            self._debug_saved_params, self.worker.model_runner.model, logger
        )


class _WeightFetcher(ForgeActor):
    """Fetches weights from torchstore and loads them into shared memory.

    This has to be colocated with the GeneratorWorker."""

    @endpoint
    async def fetch(
        self,
        *,
        version: int,
        param_names: list[str],
    ) -> dict[str, SharedTensorHandle]:
        """Fetch weights from torchstore and load them into shared memory."""
        sd = {}
        for name in param_names:
            param_key = get_param_key(version, name)
            param = await ts.get(param_key)
            # Use context manager to ensure cleanup after getting handle
            with SharedTensor(tensor=param) as shared_tensor:
                handle = shared_tensor.get_handle()
                sd[name] = handle
        return sd

Review thread on `class _WeightFetcher(ForgeActor)`:
- I think this could be a method on the generator that gets called from main, so prefetch is controlled and visible from the main loop. I am curious whether this actually has to be a separate process, since this is an async method and I would think most of the time it is waiting on ts.get.
- This has to be a separate actor because it has to be launched in a separate process.

Review thread on `with SharedTensor(tensor=param) as shared_tensor:`:
- Is the plan to move this to TS and hide the rdma/shared memory logic from the user?
- Hopefully yes.
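The `SharedTensor`/`SharedTensorHandle` helpers from `forge.util._shared_tensor` are not part of this diff, but the underlying mechanism can be sketched with the standard library: copy the tensor into a named shared-memory block and pass only a small handle (segment name, shape, dtype) between processes; the consumer then maps the same memory without another copy. A rough illustration under those assumptions, not the forge implementation:

```python
import torch
from multiprocessing import shared_memory

def put_in_shm(tensor: torch.Tensor):
    """Copy a tensor into a named shared-memory block and return a small handle."""
    t = tensor.detach().cpu().contiguous()
    shm = shared_memory.SharedMemory(create=True, size=t.numel() * t.element_size())
    view = torch.frombuffer(shm.buf, dtype=t.dtype, count=t.numel())
    view.copy_(t.view(-1))
    # The handle is just metadata; it is cheap to send to another process.
    handle = {"shm_name": shm.name, "shape": tuple(t.shape), "dtype": t.dtype}
    return handle, shm  # keep `shm` alive until all consumers are done

def get_from_shm(handle):
    """Rebuild a tensor view over the same physical memory (no extra copy)."""
    shm = shared_memory.SharedMemory(name=handle["shm_name"])
    t = torch.frombuffer(shm.buf, dtype=handle["dtype"]).view(handle["shape"])
    return t, shm  # call shm.close()/unlink() once the weights are loaded

# Producer side (e.g. a fetcher process):
handle, owner = put_in_shm(torch.randn(4, 8))
# Consumer side (e.g. the worker, after receiving the pickled handle):
param, segment = get_from_shm(handle)
assert param.shape == (4, 8)
```

In the PR, `SharedTensorHandle.drop()` on the generator side presumably plays the role of the final cleanup/unlink step in this sketch.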
General review comment:
- In general we should try to avoid changing the "public" API when we expect to quickly change the backend again. After launch we should try to keep this in mind.
- Agreed.