-
Notifications
You must be signed in to change notification settings - Fork 24
Weight loading working correctly with tp: use vllm builtin load_weights() #184
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
17e0c05
2657324
6b67be9
a6c7aef
7ec461b
762bf24
87a7bc2
98e6dd3
830efcf
5a6245c
3cf2e32
1189017
b9291cb
46e855f
68b91a4
d1e7ec6
4d61a58
d224bea
8c3471c
0f67dec
a08c96d
ab56d05
33ec0ef
7e05b3a
aaf60ec
1c0afa9
3e5a417
7addf9a
446b123
c02e988
30fbe46
9676aa0
0691b43
fed8688
26e1334
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,7 +17,6 @@ | |
import torch | ||
import torchstore as ts | ||
from monarch.actor import current_rank, endpoint, ProcMesh | ||
from torchstore.state_dict_utils import DELIM | ||
from vllm.config import VllmConfig | ||
|
||
from vllm.engine.arg_utils import EngineArgs | ||
|
@@ -40,11 +39,17 @@ | |
from vllm.v1.structured_output import StructuredOutputManager | ||
from vllm.worker.worker_base import WorkerWrapperBase | ||
|
||
from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh | ||
from forge.actors.torchstore_utils import ( | ||
extract_param_name, | ||
get_param_key, | ||
get_param_prefix, | ||
) | ||
|
||
from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh | ||
from forge.data.sharding import VLLMSharding | ||
from forge.interfaces import Policy as PolicyInterface | ||
from forge.types import ProcessConfig | ||
from forge.util.async_utils import make_sync_generator | ||
|
||
|
||
@dataclass | ||
|
@@ -364,16 +369,16 @@ async def run(self): | |
fut.set_result(request_output) | ||
|
||
@endpoint | ||
async def update_weights(self): | ||
async def update_weights(self, policy_version: int): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. You probably want to rebase on this #181 |
||
# TODO: If generating long sequences, this might be long and will block policy weight updates | ||
curr_requests = [fut for _, fut in self.requests.values()] | ||
if curr_requests: | ||
self.logger.debug(f"Waiting for {len(curr_requests)} pending requests") | ||
await asyncio.gather(*curr_requests) | ||
|
||
self.logger.debug(f"Starting weight update on {self.__class__.__name__}") | ||
await self.policy_worker.update.call(version=self.weights_version) | ||
self.weights_version += 1 | ||
await self.policy_worker.update.call(version=policy_version) | ||
self.weights_version = policy_version | ||
self.logger.info(f"Weight update completed (now v{self.weights_version})") | ||
|
||
@endpoint | ||
|
@@ -395,7 +400,6 @@ async def stop(self): | |
@dataclass | ||
class PolicyWorker(ForgeActor): | ||
vllm_config: VllmConfig | ||
state_dict_key: str = "model_state_dict" | ||
|
||
@endpoint | ||
async def setup(self): | ||
|
@@ -407,41 +411,26 @@ async def setup(self): | |
async def execute_model(self, schedule: SchedulerOutput): | ||
return self.worker.execute_model(schedule) | ||
|
||
async def _load_tensor_parallel_state_dict( | ||
self, current_state_dict: dict, version: int | ||
): | ||
""" | ||
Load full state dict from torchstore into tensor parallel model with deterministic sharding. | ||
""" | ||
sharding = VLLMSharding( | ||
self.vllm_config.parallel_config.tensor_parallel_size, self.rank | ||
) | ||
|
||
for param_name in current_state_dict.keys(): | ||
current_tensor = current_state_dict[param_name] | ||
|
||
# Load the full tensor from torchstore | ||
# TODO: only get the part of the tensor that is needed | ||
stored_tensor = await ts.get( | ||
f"{self.state_dict_key}{DELIM}{version}{DELIM}{param_name}" | ||
) | ||
sharding.load_from_source_to_target( | ||
param_name, | ||
stored_tensor, | ||
current_tensor, | ||
) | ||
|
||
@endpoint | ||
async def update(self, version: int): | ||
"""Update model weights by reading state dict from torchstore""" | ||
key = f"{self.state_dict_key}{DELIM}{version}" | ||
model = self.worker.model_runner.model | ||
current_state_dict = model.state_dict() | ||
start = time.time() | ||
await self._load_tensor_parallel_state_dict(current_state_dict, version) | ||
self.logger.debug( | ||
f"Loaded state dict from {key} in {time.time() - start} seconds" | ||
) | ||
prefix = get_param_prefix(version) | ||
self.logger.debug(f"{prefix=}") | ||
matching_keys = await ts.keys(prefix) | ||
self.logger.debug(f"{matching_keys=}") | ||
# TODO: find a way to save the original huggingface parameter names. | ||
hf_names = [extract_param_name(key) for key in matching_keys] | ||
self.logger.debug(f"{hf_names=}") | ||
loaded_weights = set() | ||
# We can't pass a generator since vllm load_weights is not async. | ||
# Instead, we just call load_weights with one parameter at a time. | ||
for name in hf_names: | ||
param = await ts.get(get_param_key(version, name)) | ||
loaded = model.load_weights([(name, param)]) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is super cool! I didn't realize you could do it per-param :) There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. yeah it's surprisingly good |
||
del param | ||
loaded_weights.update(loaded) | ||
self.logger.info(f"Updated {len(loaded_weights)} parameters") | ||
|
||
|
||
@endpoint | ||
async def setup_kv_cache(self): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
KEY_DELIM = "." | ||
|
||
|
||
def get_param_prefix(policy_version: int) -> str: | ||
return f"policy_ver_{policy_version}" | ||
|
||
|
||
def get_param_key(policy_version: int, name: str) -> str: | ||
return f"policy_ver_{policy_version}{KEY_DELIM}{name}" | ||
|
||
|
||
def extract_param_name(key: str) -> str: | ||
return KEY_DELIM.join(key.split(KEY_DELIM)[1:]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is great! I think let's split this diff to
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
added this for debugging myself, will do!