Weight loading working correctly with tp: use vllm builtin load_weights() #184
Changes from 25 commits
@@ -0,0 +1,68 @@ (new file: the toy app training configuration)
# Toy app Training Configuration

# Global configuration
group_size: 16
batch_size: 64
max_req_tokens: 64
max_res_tokens: 64
model: "Qwen/Qwen2.5-0.5B-Instruct"

# Dataset configuration
dataset:
  model: ${model}

# Policy configuration
policy:
  engine_config:
    model: ${model}
    tensor_parallel_size: 2
    pipeline_parallel_size: 1
    enforce_eager: false
  sampling_config:
    n: ${group_size}
    max_tokens: ${max_res_tokens}
    temperature: 1.0
    top_p: 1.0
  use_vllm_builtin_load: true

# Trainer configuration
trainer:
  model_name: ${model}
  learning_rate: 1e-5
  use_vllm_builtin_load: true

# Reference model configuration
ref_model:
  model_name: ${model}

# Replay buffer configuration
replay_buffer:
  batch_size: ${batch_size}
  max_policy_age: 1  # Async by 1
  dp_size: 1

services:
  dataset:
    procs: 1
    num_replicas: 1
    with_gpus: false
  policy:
    procs: 1
    num_replicas: 1
    with_gpus: true
  trainer:
    procs: 1
    num_replicas: 1
    with_gpus: true
  replay_buffer:
    procs: 1
    num_replicas: 1
    with_gpus: false
  reward_actor:
    procs: 1
    num_replicas: 1
    with_gpus: false
  ref_model:
    procs: 1
    num_replicas: 1
    with_gpus: true
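For reference, the `${...}` values above are interpolations that resolve against the top-level keys when the config is loaded. A minimal sketch of that behavior, assuming an OmegaConf-style loader (the PR does not show which loader the app uses, and the file path below is hypothetical):

```python
# Sketch of how the ${model} / ${group_size} interpolations resolve.
# Assumes an OmegaConf-style loader; the app's actual loader is not
# shown in this diff, and the path is hypothetical.
from omegaconf import OmegaConf

cfg = OmegaConf.load("toy_app.yaml")  # hypothetical path
assert cfg.policy.engine_config.model == "Qwen/Qwen2.5-0.5B-Instruct"  # ${model}
assert cfg.policy.sampling_config.n == 16  # ${group_size}
assert cfg.replay_buffer.batch_size == 64  # ${batch_size}
```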

@@ -0,0 +1,37 @@ (new file, presumably forge/actors/_torchstore_utils.py given the import added below)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass

import torch
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.metadata import Metadata as DcpMeta

KEY_DELIM = "."


@dataclass
class DcpHandle:
    checkpoint_id: str = ""
    metadata: DcpMeta | None = None


def load_tensor_from_dcp(handle: DcpHandle, param_name: str) -> torch.Tensor:
    tensor_meta = handle.metadata.state_dict_metadata[param_name]
    buffer = torch.empty(tensor_meta.size, dtype=tensor_meta.properties.dtype)
    dcp.load(checkpoint_id=handle.checkpoint_id, state_dict={param_name: buffer})
    return buffer


def get_param_prefix(policy_version: int) -> str:
    return f"policy_ver_{policy_version}"


def get_param_key(policy_version: int, name: str) -> str:
    return f"policy_ver_{policy_version}{KEY_DELIM}{name}"


def extract_param_name(key: str) -> str:
    return KEY_DELIM.join(key.split(KEY_DELIM)[1:])
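A quick usage sketch for these helpers (the parameter name is illustrative): since `KEY_DELIM` is `"."` and HF parameter names are themselves dotted, `extract_param_name` rejoins everything after the version prefix rather than taking a single split segment.

```python
key = get_param_key(7, "model.layers.0.self_attn.q_proj.weight")
assert key == "policy_ver_7.model.layers.0.self_attn.q_proj.weight"
assert key.startswith(get_param_prefix(7))
# Rejoining split(".")[1:] recovers the full dotted parameter name.
assert extract_param_name(key) == "model.layers.0.self_attn.q_proj.weight"
```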

@@ -42,12 +42,18 @@
from vllm.v1.structured_output import StructuredOutputManager
from vllm.worker.worker_base import WorkerWrapperBase

from forge.actors._torchstore_utils import (
    DcpHandle,
    extract_param_name,
    get_param_key,
    get_param_prefix,
    load_tensor_from_dcp,
)

from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh
from forge.data.sharding import VLLMSharding
from forge.data_models.completion import Completion
from forge.data_models.prompt import to_prompt

from forge.interfaces import Policy as PolicyInterface
from forge.types import ProcessConfig

@@ -126,6 +132,7 @@ def create_vllm_config(self) -> VllmConfig:
class Policy(PolicyInterface):
    engine_config: EngineConfig | Mapping = field(default_factory=EngineConfig)
    sampling_config: SamplingConfig | Mapping = field(default_factory=SamplingConfig)
    use_vllm_builtin_load: bool = True
    available_devices: str | None = None
    # Gets set up by setup
    sampling_params: SamplingParams | None = None

@@ -144,6 +151,7 @@ def __post_init__(self):
            self.engine_config = EngineConfig.from_dict(self.engine_config)
        if isinstance(self.sampling_config, Mapping):
            self.sampling_config = SamplingConfig.from_dict(self.sampling_config)
        # No conversion needed for the boolean flag.

    @classmethod
    async def launch(  # pyright: ignore[reportIncompatibleMethodOverride]

@@ -196,6 +204,7 @@ async def launch(  # pyright: ignore[reportIncompatibleMethodOverride]
            sampling_config=sampling_config,
            available_devices=available_devices,
            policy_worker=workers,
            **kwargs,
        )
        policy._policy_proc = policy_proc
        policy._worker_procs = worker_procs

@@ -387,7 +396,22 @@ async def update_weights(self, policy_version: int):
            await asyncio.gather(*curr_requests)

        logger.debug(f"Starting weight update on {self.__class__.__name__}")
        if self.use_vllm_builtin_load:
            await self.policy_worker.update.call(version=policy_version)
        else:
            await self.policy_worker.update_DEPRECATED.call(version=policy_version)
        self.policy_version = policy_version
        logger.info(f"Weight update completed (now v{self.policy_version})")

[Review comment on the new branch] Eventually, this will be the default right?
[Reply] seems like the plan

    @endpoint
    async def update_weights_DEPRECATED(self, policy_version: int):  # noqa: N802
        # TODO: If generating long sequences, this might be long and will block policy weight updates
        curr_requests = [fut for _, fut in self.requests.values()]
        if curr_requests:
            logger.debug(f"Waiting for {len(curr_requests)} pending requests")
            await asyncio.gather(*curr_requests)

        await self.policy_worker.update_DEPRECATED.call(version=policy_version)
        self.policy_version = policy_version
        logger.info(f"Weight update completed (now v{self.policy_version})")
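For context, the worker-side `update` below pulls per-parameter entries that some producer previously published under versioned keys. A hedged sketch of what that producer side might look like, assuming torchstore exposes an async `put()` counterpart to the `ts.keys()`/`ts.get()` calls used in this diff (the function and its arguments are hypothetical, not part of this PR):

```python
# Hypothetical producer-side sketch: publish each parameter under a
# versioned key so PolicyWorker.update() can find it. Assumes the
# torchstore module is importable as `ts` (matching the diff's alias)
# and exposes an async put() mirroring its keys()/get() calls.
import torchstore as ts  # assumed module name

from forge.actors._torchstore_utils import get_param_key


async def publish_weights(model, version: int) -> None:
    for name, tensor in model.state_dict().items():
        # One entry per parameter, keyed as policy_ver_{version}.{name}
        await ts.put(get_param_key(version, name), tensor.detach().cpu())
```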

@@ -454,7 +478,11 @@ def _extract_logprobs(self, one_sample: CompletionOutput) -> torch.Tensor | None:
class PolicyWorker(ForgeActor):
    vllm_config: VllmConfig
    state_dict_key: str = "model_state_dict"
    # TODO: remove this later since no plumbing exists to change this value.
    # Also, whether to use dcp or not can be inferred from the torchstore get() call.
    use_dcp: bool = True
    # Cache hf param names on the first update call.
    hf_param_names = []

    # Used for testing purposes only.
    _test_prev_params = {}

@@ -509,8 +537,9 @@
        )

    @endpoint
    async def update_DEPRECATED(self, version: int):  # noqa: N802
        """Update model weights by reading the state dict from torchstore.

        Deprecated: this uses manual sharding logic, which is buggy."""
        key = f"{self.state_dict_key}{DELIM}{version}"
        model = self.worker.model_runner.model
        current_state_dict = model.state_dict()

@@ -520,6 +549,40 @@
            f"Loaded state dict from {key} in {time.perf_counter() - start} seconds"
        )

    @endpoint
    async def update(self, version: int):
        """Update model weights by reading the state dict from torchstore."""
        model = self.worker.model_runner.model
        prefix = get_param_prefix(version)
        logger.debug(f"{prefix=}")
        matching_keys = await ts.keys(prefix)
        logger.debug(f"{matching_keys=}")
        if not self.hf_param_names:
            self.hf_param_names = [extract_param_name(key) for key in matching_keys]
        loaded_weights = set()
        # We can't pass a generator since vllm's load_weights is not async.
        # Instead, we call load_weights with one parameter at a time.
        start = time.perf_counter()
        for name in self.hf_param_names:
            param_key = get_param_key(version, name)
            tensor_or_handle = await ts.get(param_key)
            if isinstance(tensor_or_handle, torch.Tensor):
                param = tensor_or_handle
            elif isinstance(tensor_or_handle, DcpHandle):
                param = load_tensor_from_dcp(tensor_or_handle, name)
                logger.debug(f"Loaded {name} from DCP with handle {tensor_or_handle}")
            else:
                raise RuntimeError(
                    f"Unexpected type for {param_key}: {type(tensor_or_handle)}"
                )
            loaded = model.load_weights([(name, param)])
[Review comment] This is super cool! I didn't realize you could do it per-param :)
[Reply] yeah it's surprisingly good
            del param
            loaded_weights.update(loaded)
        logger.info(
            f"[PolicyWorker::update] Updated {len(loaded_weights)} parameters, took {time.perf_counter() - start} seconds"
        )
        logger.debug(f"[PolicyWorker::update] Loaded weights: {loaded_weights}")

    @endpoint
    async def setup_kv_cache(self):
        """Based on vllm/v1/engine/core.py:EngineCore._initialize_kv_caches
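The exchange above points at the contract this relies on: vLLM's `model.load_weights` accepts an iterable of `(name, tensor)` pairs and returns the names it actually loaded, which is why the loop can feed one pair at a time and union the results. A minimal sketch of that pattern (`model` and `named_tensors` are stand-ins, not part of this PR):

```python
from typing import Iterable

import torch


def load_one_by_one(model, named_tensors: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    """Feed load_weights one (name, tensor) pair at a time, as update() does."""
    loaded: set[str] = set()
    for name, tensor in named_tensors:
        # load_weights returns the parameter names it consumed; unioning
        # them lets the caller check that nothing was silently skipped.
        loaded |= set(model.load_weights([(name, tensor)]))
    return loaded
```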
[Review comment] ?
[Reply] yeah this won't work lemme revert