-
Notifications
You must be signed in to change notification settings - Fork 24
Weight loading working correctly with tp: use vllm builtin load_weights() #184
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 14 commits
17e0c05
2657324
6b67be9
a6c7aef
7ec461b
762bf24
87a7bc2
98e6dd3
830efcf
5a6245c
3cf2e32
1189017
b9291cb
46e855f
68b91a4
d1e7ec6
4d61a58
d224bea
8c3471c
0f67dec
a08c96d
ab56d05
33ec0ef
7e05b3a
aaf60ec
1c0afa9
3e5a417
7addf9a
446b123
c02e988
30fbe46
9676aa0
0691b43
fed8688
26e1334
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
# Toy app Training Configuration | ||
|
||
# Global configuration | ||
group_size: 16 | ||
batch_size: 64 | ||
max_req_tokens: 64 | ||
max_res_tokens: 64 | ||
model: "Qwen/Qwen2.5-0.5B-Instruct" | ||
|
||
# Dataset configuration | ||
dataset: | ||
model: ${model} | ||
|
||
# Policy configuration | ||
policy: | ||
engine_config: | ||
model: ${model} | ||
tensor_parallel_size: 2 | ||
pipeline_parallel_size: 1 | ||
enforce_eager: false | ||
sampling_config: | ||
n: ${group_size} | ||
max_tokens: ${max_res_tokens} | ||
temperature: 1.0 | ||
top_p: 1.0 | ||
use_vllm_builtin_load: true | ||
|
||
# Trainer configuration | ||
trainer: | ||
model_name: ${model} | ||
learning_rate: 1e-5 | ||
use_vllm_builtin_load: true | ||
|
||
# Reference model configuration | ||
ref_model: | ||
model_name: ${model} | ||
|
||
# Replay buffer configuration | ||
replay_buffer: | ||
batch_size: ${batch_size} | ||
max_policy_age: 1 # Async by 1 | ||
dp_size: 1 | ||
|
||
services: | ||
dataset: | ||
procs: 1 | ||
num_replicas: 1 | ||
with_gpus: false | ||
policy: | ||
procs: 1 | ||
num_replicas: 1 | ||
with_gpus: true | ||
trainer: | ||
procs: 1 | ||
num_replicas: 1 | ||
with_gpus: true | ||
replay_buffer: | ||
procs: 1 | ||
num_replicas: 1 | ||
with_gpus: false | ||
reward_actor: | ||
procs: 1 | ||
num_replicas: 1 | ||
with_gpus: false | ||
ref_model: | ||
procs: 1 | ||
num_replicas: 1 | ||
with_gpus: true |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -43,12 +43,16 @@ | |
from vllm.v1.structured_output import StructuredOutputManager | ||
from vllm.worker.worker_base import WorkerWrapperBase | ||
|
||
from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh | ||
from forge.actors.torchstore_utils import ( | ||
extract_param_name, | ||
get_param_key, | ||
get_param_prefix, | ||
) | ||
|
||
from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh | ||
from forge.data.sharding import VLLMSharding | ||
from forge.data_models.completion import Completion | ||
from forge.data_models.prompt import to_prompt | ||
|
||
from forge.interfaces import Policy as PolicyInterface | ||
from forge.types import ProcessConfig | ||
|
||
|
@@ -127,6 +131,8 @@ def create_vllm_config(self) -> VllmConfig: | |
class Policy(PolicyInterface): | ||
engine_config: EngineConfig | Mapping = field(default_factory=EngineConfig) | ||
sampling_config: SamplingConfig | Mapping = field(default_factory=SamplingConfig) | ||
use_vllm_builtin_load: bool = False | ||
test_blah_blah: int = 0 | ||
available_devices: str | None = None | ||
# Gets set up by setup | ||
sampling_params: SamplingParams | None = None | ||
|
@@ -145,6 +151,7 @@ def __post_init__(self): | |
self.engine_config = EngineConfig.from_dict(self.engine_config) | ||
if isinstance(self.sampling_config, Mapping): | ||
self.sampling_config = SamplingConfig.from_dict(self.sampling_config) | ||
# No conversion needed for boolean flag | ||
|
||
@classmethod | ||
async def launch( # pyright: ignore[reportIncompatibleMethodOverride] | ||
|
@@ -153,6 +160,7 @@ async def launch( # pyright: ignore[reportIncompatibleMethodOverride] | |
process_config: ProcessConfig, | ||
engine_config: EngineConfig | Mapping = EngineConfig(), | ||
sampling_config: SamplingConfig | Mapping = SamplingConfig(), | ||
use_vllm_builtin_load: bool = False, | ||
available_devices: str | None = None, | ||
**kwargs, | ||
) -> "Policy": | ||
|
@@ -191,6 +199,7 @@ async def launch( # pyright: ignore[reportIncompatibleMethodOverride] | |
cls, | ||
engine_config=engine_config, | ||
sampling_config=sampling_config, | ||
use_vllm_builtin_load=use_vllm_builtin_load, | ||
available_devices=available_devices, | ||
policy_worker=workers, | ||
) | ||
|
@@ -384,7 +393,10 @@ async def update_weights(self, policy_version: int): | |
await asyncio.gather(*curr_requests) | ||
|
||
logger.debug(f"Starting weight update on {self.__class__.__name__}") | ||
await self.policy_worker.update.call(version=policy_version) | ||
if self.use_vllm_builtin_load: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Eventually, this will be the default right? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. seems like the plan |
||
await self.policy_worker._update_hf_nonsharded.call(version=policy_version) | ||
else: | ||
await self.policy_worker._update_sharded.call(version=policy_version) | ||
self.policy_version = policy_version | ||
logger.info(f"Weight update completed (now v{self.policy_version})") | ||
|
||
|
@@ -496,7 +508,7 @@ async def _load_tensor_parallel_state_dict( | |
) | ||
|
||
@endpoint | ||
async def update(self, version: int): | ||
async def _update_sharded(self, version: int): | ||
"""Update model weights by reading state dict from torchstore""" | ||
key = f"{self.state_dict_key}{DELIM}{version}" | ||
model = self.worker.model_runner.model | ||
|
@@ -505,6 +517,27 @@ async def update(self, version: int): | |
await self._load_tensor_parallel_state_dict(current_state_dict, version) | ||
logger.debug(f"Loaded state dict from {key} in {time.time() - start} seconds") | ||
|
||
@endpoint | ||
async def _update_hf_nonsharded(self, version: int): | ||
|
||
"""Update model weights by reading state dict from torchstore""" | ||
model = self.worker.model_runner.model | ||
prefix = get_param_prefix(version) | ||
self.logger.debug(f"{prefix=}") | ||
matching_keys = await ts.keys(prefix) | ||
self.logger.debug(f"{matching_keys=}") | ||
# TODO: find a way to save the original huggingface parameter names. | ||
hf_names = [extract_param_name(key) for key in matching_keys] | ||
self.logger.debug(f"{hf_names=}") | ||
loaded_weights = set() | ||
# We can't pass a generator since vllm load_weights is not async. | ||
# Instead, we just call load_weights with one parameter at a time. | ||
for name in hf_names: | ||
param = await ts.get(get_param_key(version, name)) | ||
loaded = model.load_weights([(name, param)]) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is super cool! I didn't realize you could do it per-param :) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yeah it's surprisingly good |
||
del param | ||
loaded_weights.update(loaded) | ||
self.logger.info(f"Updated {len(loaded_weights)} parameters") | ||
|
||
|
||
@endpoint | ||
async def setup_kv_cache(self): | ||
"""Based on vllm/v1/engine/core.py:EngineCore._initialize_kv_caches | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
KEY_DELIM = "." | ||
|
||
|
||
def get_param_prefix(policy_version: int) -> str: | ||
return f"policy_ver_{policy_version}" | ||
|
||
|
||
def get_param_key(policy_version: int, name: str) -> str: | ||
return f"policy_ver_{policy_version}{KEY_DELIM}{name}" | ||
|
||
|
||
def extract_param_name(key: str) -> str: | ||
return KEY_DELIM.join(key.split(KEY_DELIM)[1:]) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -36,6 +36,8 @@ | |
from torchtitan.experiments.forge.engine import ForgeEngine | ||
from torchtitan.experiments.forge.job_config import ForgeJobConfig | ||
|
||
from forge.actors.torchstore_utils import get_param_key | ||
|
||
from forge.controller import ForgeActor | ||
from forge.data.utils import batch_to_device | ||
|
||
|
@@ -93,6 +95,7 @@ class RLTrainer(ForgeActor): | |
activation_checkpoint: ActivationCheckpoint = field( | ||
default_factory=ActivationCheckpoint | ||
) | ||
use_vllm_builtin_load: bool = False | ||
compile: Compile = field(default_factory=Compile) | ||
float8: Float8 = field(default_factory=Float8) | ||
comm: Comm = field(default_factory=Comm) | ||
|
@@ -142,7 +145,7 @@ def __post_init__(self): | |
async def setup(self): | ||
# TODO: update ForgeEngine to not use ForgeJobConfig | ||
engine_config = {f.name: getattr(self, f.name) for f in fields(self)} | ||
for key in {"loss", "state_dict_key", "use_dcp"}: | ||
for key in {"loss", "state_dict_key", "use_dcp", "use_vllm_builtin_load"}: | ||
engine_config.pop(key) # Not part of job config | ||
self.engine = ForgeEngine(ForgeJobConfig(**engine_config)) | ||
self.engine.checkpointer.load(step=self.step) | ||
|
@@ -248,6 +251,12 @@ def train_step( | |
|
||
@endpoint | ||
async def push_weights(self, policy_version: int) -> None: | ||
if self.use_vllm_builtin_load: | ||
await self._push_weights_hf_nonsharded(policy_version) | ||
else: | ||
await self._push_weights_sharded(policy_version) | ||
|
||
async def _push_weights_sharded(self, policy_version: int) -> None: | ||
|
||
# Save to torchstore. Hacking in to the Checkpointer's prepped state-dict for now. | ||
# TODO: | ||
# 1. Checkpoint invokes state-dict flattening during dcp_save for [MODEL]. | ||
|
@@ -290,6 +299,22 @@ async def push_weights(self, policy_version: int) -> None: | |
|
||
logger.debug(f"Pushed weights to {key} in {end_time - start_time:.2f} seconds") | ||
|
||
async def _push_weights_hf_nonsharded(self, policy_version: int) -> None: | ||
"""Push weights to torchstore in HF format, non-sharded.""" | ||
if "model" not in self.engine.checkpointer.states: | ||
raise RuntimeError("Model state not found in checkpointer state") | ||
|
||
sd = self.engine.checkpointer.states["model"].state_dict() | ||
flattened_state_dict, _ = flatten_state_dict(sd) | ||
if self.engine.checkpointer.sd_adapter is None: | ||
raise RuntimeError( | ||
"Trying to save checkpoint in HF safetensors format, but sd_adapter is not provided." | ||
) | ||
hf_state_dict = self.engine.checkpointer.sd_adapter.to_hf(flattened_state_dict) | ||
for name, param in hf_state_dict.items(): | ||
key = get_param_key(policy_version, name) | ||
await ts.put(key, param) | ||
|
||
@endpoint | ||
async def cleanup(self) -> None: | ||
if self.engine.checkpointer: | ||
|
Uh oh!
There was an error while loading. Please reload this page.