Weight loading working correctly with tp: use vllm builtin load_weights() #184
New file (toy app training config):

@@ -0,0 +1,68 @@
```yaml
# Toy app Training Configuration

# Global configuration
group_size: 16
batch_size: 64
max_req_tokens: 64
max_res_tokens: 64
model: "Qwen/Qwen2.5-0.5B-Instruct"

# Dataset configuration
dataset:
  model: ${model}

# Policy configuration
policy:
  engine_config:
    model: ${model}
    tensor_parallel_size: 2
    pipeline_parallel_size: 1
    enforce_eager: false
  sampling_config:
    n: ${group_size}
    max_tokens: ${max_res_tokens}
    temperature: 1.0
    top_p: 1.0
  use_vllm_builtin_load: true

# Trainer configuration
trainer:
  model_name: ${model}
  learning_rate: 1e-5
  use_vllm_builtin_load: true

# Reference model configuration
ref_model:
  model_name: ${model}

# Replay buffer configuration
replay_buffer:
  batch_size: ${batch_size}
  max_policy_age: 1  # Async by 1
  dp_size: 1

services:
  dataset:
    procs: 1
    num_replicas: 1
    with_gpus: false
  policy:
    procs: 1
    num_replicas: 1
    with_gpus: true
  trainer:
    procs: 1
    num_replicas: 1
    with_gpus: true
  replay_buffer:
    procs: 1
    num_replicas: 1
    with_gpus: false
  reward_actor:
    procs: 1
    num_replicas: 1
    with_gpus: false
  ref_model:
    procs: 1
    num_replicas: 1
    with_gpus: true
```
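The `${...}` references in this config are interpolation syntax. As a hedged illustration, here is how they would resolve assuming an OmegaConf-style loader (an assumption: the actual config-loading code is not part of this diff, and the filename is hypothetical):

```python
# Sketch only: assumes the ${...} interpolations above are resolved by
# OmegaConf; the repo's real loader is not shown in this PR.
from omegaconf import OmegaConf

cfg = OmegaConf.load("toy_app.yaml")  # hypothetical filename
resolved = OmegaConf.to_container(cfg, resolve=True)

assert resolved["policy"]["engine_config"]["model"] == "Qwen/Qwen2.5-0.5B-Instruct"
assert resolved["policy"]["sampling_config"]["n"] == 16  # from ${group_size}
assert resolved["replay_buffer"]["batch_size"] == 64  # from ${batch_size}
```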
Changes to the policy actor (imports):

@@ -43,12 +43,16 @@
```python
from vllm.v1.structured_output import StructuredOutputManager
from vllm.worker.worker_base import WorkerWrapperBase

from forge.actors.torchstore_utils import (
    extract_param_name,
    get_param_key,
    get_param_prefix,
)

from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh
from forge.data.sharding import VLLMSharding
from forge.data_models.completion import Completion
from forge.data_models.prompt import to_prompt

from forge.interfaces import Policy as PolicyInterface
from forge.types import ProcessConfig
```
@@ -127,6 +131,7 @@ def create_vllm_config(self) -> VllmConfig:
```python
class Policy(PolicyInterface):
    engine_config: EngineConfig | Mapping = field(default_factory=EngineConfig)
    sampling_config: SamplingConfig | Mapping = field(default_factory=SamplingConfig)
    use_vllm_builtin_load: bool = True
    available_devices: str | None = None
    # Gets set up by setup
    sampling_params: SamplingParams | None = None
```
@@ -145,6 +150,7 @@ def __post_init__(self):
```python
            self.engine_config = EngineConfig.from_dict(self.engine_config)
        if isinstance(self.sampling_config, Mapping):
            self.sampling_config = SamplingConfig.from_dict(self.sampling_config)
        # No conversion needed for boolean flag

    @classmethod
    async def launch(  # pyright: ignore[reportIncompatibleMethodOverride]
```
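`EngineConfig.from_dict` and `SamplingConfig.from_dict` are called here but defined elsewhere; as a rough illustration of this Mapping-to-dataclass coercion pattern, one plausible shape (the dataclass below is hypothetical, not the repo's actual implementation):

```python
# Hypothetical sketch of a from_dict-style coercion; the real
# EngineConfig/SamplingConfig implementations are not in this diff.
from collections.abc import Mapping
from dataclasses import dataclass, fields


@dataclass
class EngineConfigSketch:
    model: str = ""
    tensor_parallel_size: int = 1
    pipeline_parallel_size: int = 1
    enforce_eager: bool = False

    @classmethod
    def from_dict(cls, d: Mapping) -> "EngineConfigSketch":
        # Keep only keys that correspond to dataclass fields.
        known = {f.name for f in fields(cls)}
        return cls(**{k: v for k, v in d.items() if k in known})


cfg = EngineConfigSketch.from_dict(
    {"model": "Qwen/Qwen2.5-0.5B-Instruct", "tensor_parallel_size": 2}
)
assert cfg.tensor_parallel_size == 2
```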
@@ -153,6 +159,7 @@ async def launch(  # pyright: ignore[reportIncompatibleMethodOverride]
```python
        process_config: ProcessConfig,
        engine_config: EngineConfig | Mapping = EngineConfig(),
        sampling_config: SamplingConfig | Mapping = SamplingConfig(),
        use_vllm_builtin_load: bool = False,
        available_devices: str | None = None,
        **kwargs,
    ) -> "Policy":
```
@@ -191,6 +198,7 @@ async def launch(  # pyright: ignore[reportIncompatibleMethodOverride]
```python
            cls,
            engine_config=engine_config,
            sampling_config=sampling_config,
            use_vllm_builtin_load=use_vllm_builtin_load,
            available_devices=available_devices,
            policy_worker=workers,
        )
```
@@ -384,7 +392,22 @@ async def update_weights(self, policy_version: int):
```python
            await asyncio.gather(*curr_requests)

        logger.debug(f"Starting weight update on {self.__class__.__name__}")
        if self.use_vllm_builtin_load:
            await self.policy_worker.update.call(version=policy_version)
        else:
            await self.policy_worker.update_DEPRECATED.call(version=policy_version)
        self.policy_version = policy_version
        logger.info(f"Weight update completed (now v{self.policy_version})")

    @endpoint
    async def update_weights_DEPRECATED(self, policy_version: int):  # noqa: N802
        # TODO: If generating long sequences, this might be long and will block policy weight updates
        curr_requests = [fut for _, fut in self.requests.values()]
        if curr_requests:
            logger.debug(f"Waiting for {len(curr_requests)} pending requests")
            await asyncio.gather(*curr_requests)

        await self.policy_worker.update_DEPRECATED.call(version=policy_version)
        self.policy_version = policy_version
        logger.info(f"Weight update completed (now v{self.policy_version})")
```

Review thread on `if self.use_vllm_builtin_load:`:

Review comment: Eventually, this will be the default right?
Reply: seems like the plan
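From the caller's point of view the flag is invisible: `update_weights` routes to the right worker endpoint internally. A hedged caller-side sketch (the handle and variable names are assumptions, not from this diff):

```python
# Hypothetical caller: `policy` is a handle to the Policy service and
# `step` is the current training step. The endpoint internally picks
# policy_worker.update or policy_worker.update_DEPRECATED based on
# use_vllm_builtin_load.
await policy.update_weights.call(policy_version=step)
```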
@@ -496,15 +519,37 @@ async def _load_tensor_parallel_state_dict(
```python
        )

    @endpoint
    async def update_DEPRECATED(self, version: int):  # noqa: N802
        """Update model weights by reading state dict from torchstore.
        Deprecated. This uses manual sharding logic which is buggy."""
        key = f"{self.state_dict_key}{DELIM}{version}"
        model = self.worker.model_runner.model
        current_state_dict = model.state_dict()
        start = time.time()
        await self._load_tensor_parallel_state_dict(current_state_dict, version)
        logger.debug(f"Loaded state dict from {key} in {time.time() - start} seconds")

    @endpoint
    async def update(self, version: int):
        """Update model weights by reading state dict from torchstore"""
        model = self.worker.model_runner.model
        prefix = get_param_prefix(version)
        self.logger.debug(f"{prefix=}")
        matching_keys = await ts.keys(prefix)
        self.logger.debug(f"{matching_keys=}")
        # TODO: find a way to save the original huggingface parameter names.
        hf_names = [extract_param_name(key) for key in matching_keys]
        self.logger.debug(f"{hf_names=}")
        loaded_weights = set()
        # We can't pass a generator since vllm load_weights is not async.
        # Instead, we just call load_weights with one parameter at a time.
        for name in hf_names:
            param = await ts.get(get_param_key(version, name))
            loaded = model.load_weights([(name, param)])
            del param
            loaded_weights.update(loaded)
        self.logger.info(f"Updated {len(loaded_weights)} parameters")

    @endpoint
    async def setup_kv_cache(self):
        """Based on vllm/v1/engine/core.py:EngineCore._initialize_kv_caches
```

Review thread on the per-parameter `load_weights` loop:

Review comment: This is super cool! I didn't realize you could do it per-param :)
Reply: yeah it's surprisingly good
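For illustration, a minimal, self-contained sketch of the per-parameter streaming pattern used in `update` above. `FakeStore` and `FakeModel` are stand-ins invented for this example (torchstore and a real vLLM model are not importable here); `FakeModel.load_weights` only mirrors the call shape used in the diff: an iterable of `(name, tensor)` pairs in, a set of loaded parameter names out.

```python
# Stand-ins for torchstore and the vLLM model; invented for this sketch.
import asyncio

import torch


class FakeStore:
    """Mimics the two torchstore calls used above: keys() and get()."""

    def __init__(self, params: dict[str, torch.Tensor]):
        self._data = params

    async def keys(self, prefix: str) -> list[str]:
        return [k for k in self._data if k.startswith(prefix)]

    async def get(self, key: str) -> torch.Tensor:
        return self._data[key]


class FakeModel:
    """Mimics the load_weights call shape: pairs in, loaded names out."""

    def __init__(self):
        self.params = {"w": torch.zeros(2), "b": torch.zeros(2)}

    def load_weights(self, weights) -> set[str]:
        loaded = set()
        for name, tensor in weights:
            self.params[name].copy_(tensor)
            loaded.add(name)
        return loaded


async def stream_update(store: FakeStore, model: FakeModel, version: int) -> set[str]:
    prefix = f"policy_ver_{version}"
    names = [k.split(".", 1)[1] for k in await store.keys(prefix)]
    loaded_weights: set[str] = set()
    for name in names:
        param = await store.get(f"{prefix}.{name}")
        # load_weights is synchronous, so feed it one parameter at a time
        # rather than a generator over the whole state dict.
        loaded_weights.update(model.load_weights([(name, param)]))
        del param  # release the fetched copy before the next fetch
    return loaded_weights


store = FakeStore({"policy_ver_1.w": torch.ones(2), "policy_ver_1.b": torch.ones(2)})
model = FakeModel()
print(asyncio.run(stream_update(store, model, version=1)))  # {'w', 'b'}
```

One parameter at a time also bounds peak memory: only a single fetched tensor is alive at once, which matters when a full state dict would not fit alongside the live model weights.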
New file (torchstore key utilities, imported above as `forge.actors.torchstore_utils`):

@@ -0,0 +1,19 @@
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

KEY_DELIM = "."


def get_param_prefix(policy_version: int) -> str:
    return f"policy_ver_{policy_version}"


def get_param_key(policy_version: int, name: str) -> str:
    return f"policy_ver_{policy_version}{KEY_DELIM}{name}"


def extract_param_name(key: str) -> str:
    return KEY_DELIM.join(key.split(KEY_DELIM)[1:])
```
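A quick round-trip check of these helpers (the parameter name is a hypothetical example). Note that `extract_param_name` drops only the first `KEY_DELIM`-separated segment, so the dots inside HuggingFace-style parameter names survive intact:

```python
from forge.actors.torchstore_utils import (
    extract_param_name,
    get_param_key,
    get_param_prefix,
)

# Hypothetical HF-style parameter name; its internal dots are preserved.
name = "model.layers.0.self_attn.qkv_proj.weight"
key = get_param_key(3, name)

assert key == "policy_ver_3.model.layers.0.self_attn.qkv_proj.weight"
assert key.startswith(get_param_prefix(3))
assert extract_param_name(key) == name
```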
Review thread on the PR:

Review comment: If we're confident in this fix, we should just fully delete the old way. My thinking is as follows: […]
Reply: Yes, the new one is the default now! I think the plan is to keep the DEPRECATED method just for benchmarking purposes now? @JenniferWang