-
Notifications
You must be signed in to change notification settings - Fork 17
Add YAML-based configuration support for vLLM main #116
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 20 commits
da21e1d
5c72908
b4d7a61
02d77c6
fd1d38b
f79beee
d8d775a
7301e10
64687d9
9278d75
d2d7107
14b5e4a
38f7927
2a1e021
412398c
d998061
8d38eb8
a3e755d
9dd396b
935fdc1
35fd71e
187a65d
0d26242
ba74b43
063afe6
a85f7b1
d94d326
5815656
2a31156
4778336
cb42997
eab380a
d575409
a72f4de
b19fe24
fc809f8
6ca7c2b
f1c24fb
4dc2e89
00c7fc9
4445624
a7dfd02
1ed76c4
c38685f
4191fa6
327828b
fe9acae
23e5ef6
7b904fc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
policy: | ||
worker_params: | ||
model: "meta-llama/Llama-3.1-8B-Instruct" | ||
tensor_parallel_size: 2 | ||
pipeline_parallel_size: 1 | ||
enforce_eager: true | ||
vllm_args: null | ||
sampling_params: | ||
num_samples: 2 | ||
guided_decoding: false | ||
available_devices: null | ||
|
||
service_config: | ||
procs_per_replica: 2 | ||
num_replicas: 1 | ||
with_gpus: true | ||
|
||
# Optional, otherwise argparse fallback kicks in | ||
prompt: "Tell me a joke" |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,18 +8,14 @@ | |
import logging | ||
import os | ||
import sys | ||
from collections.abc import Mapping | ||
from copy import copy | ||
from dataclasses import asdict, dataclass, field | ||
from typing import Dict, List | ||
|
||
import torch | ||
|
||
from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh | ||
|
||
from forge.data.sharding import VLLMSharding | ||
from forge.interfaces import Policy as PolicyInterface | ||
from forge.types import ProcessConfig | ||
from monarch.actor import current_rank, endpoint, ProcMesh | ||
from omegaconf import DictConfig | ||
from torchstore import MultiProcessStore | ||
from torchstore._state_dict_utils import DELIM | ||
|
||
|
@@ -43,6 +39,12 @@ | |
from vllm.v1.structured_output import StructuredOutputManager | ||
from vllm.worker.worker_base import WorkerWrapperBase | ||
|
||
from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh | ||
|
||
from forge.data.sharding import VLLMSharding | ||
from forge.interfaces import Policy as PolicyInterface | ||
from forge.types import ProcessConfig | ||
|
||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
@@ -61,7 +63,7 @@ class SamplingOverrides: | |
guided_decoding: Whether to use guided decoding. | ||
""" | ||
|
||
num_samples: int | ||
num_samples: int = 1 | ||
|
||
guided_decoding: bool = False | ||
max_tokens: int = 512 | ||
|
||
|
@@ -79,23 +81,27 @@ class WorkerConfig: | |
vllm_args: vLLM engine args. | ||
""" | ||
|
||
model: str | ||
model: str = "meta-llama/Llama-3.1-8B-Instruct" | ||
tensor_parallel_size: int = 1 | ||
|
||
pipeline_parallel_size: int = 1 | ||
enforce_eager: bool = False | ||
vllm_args: EngineArgs = None | ||
|
||
vllm_args: EngineArgs = field(default_factory=EngineArgs) | ||
|
||
@dataclass | ||
class PolicyConfig: | ||
worker_params: WorkerConfig | ||
sampling_params: SamplingOverrides | ||
available_devices: str = None | ||
@classmethod | ||
def from_dict(cls, d: dict): | ||
DNXie marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
d = dict(d) | ||
if "vllm_args" in d and isinstance(d["vllm_args"], dict): | ||
d["vllm_args"] = EngineArgs(**d["vllm_args"]) | ||
else: | ||
d["vllm_args"] = EngineArgs() | ||
return cls(**d) | ||
|
||
|
||
@dataclass | ||
class Policy(PolicyInterface): | ||
config: PolicyConfig | ||
worker_params: WorkerConfig = field(default_factory=WorkerConfig) | ||
sampling_overrides: SamplingOverrides = field(default_factory=SamplingOverrides) | ||
DNXie marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
available_devices: str | None = None | ||
# Gets set up by setup | ||
sampling_params: SamplingParams | None = None | ||
lora_request: LoRARequest | None = None | ||
|
@@ -108,13 +114,19 @@ def __post_init__(self): | |
self._policy_proc: ProcMesh | None = None | ||
self._worker_procs: ProcMesh | None = None | ||
self.weights_version: int = 0 | ||
if isinstance(self.worker_params, Mapping): | ||
self.worker_params = WorkerConfig.from_dict(self.worker_params) | ||
if isinstance(self.sampling_overrides, dict): | ||
DNXie marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
self.sampling_overrides = SamplingOverrides(**self.sampling_overrides) | ||
|
||
@classmethod | ||
async def launch( # pyright: ignore[reportIncompatibleMethodOverride] | ||
cls: type["Policy"], | ||
*, | ||
process_config: ProcessConfig, | ||
config: PolicyConfig, | ||
worker_params: WorkerConfig | dict = WorkerConfig(), | ||
DNXie marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
sampling_overrides: SamplingOverrides | dict = SamplingOverrides(), | ||
available_devices: str | None = None, | ||
store: MultiProcessStore | None = None, | ||
**kwargs, | ||
) -> "Policy": | ||
|
@@ -128,16 +140,26 @@ async def launch( # pyright: ignore[reportIncompatibleMethodOverride] | |
policy_proc_config.with_gpus = False | ||
|
||
policy_proc = await get_proc_mesh(process_config=policy_proc_config) | ||
workers = await worker_procs.spawn( | ||
"vllm_worker", PolicyWorker, **asdict(config.worker_params) | ||
) | ||
|
||
if isinstance(worker_params, (dict, DictConfig)): | ||
DNXie marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
worker_params = WorkerConfig.from_dict(worker_params) | ||
|
||
if isinstance(worker_params, (dict, DictConfig)): | ||
sampling_overrides = SamplingOverrides(**sampling_overrides) | ||
|
||
worker_dict = asdict(worker_params) | ||
worker_dict["vllm_args"] = worker_params.vllm_args | ||
|
||
workers = await worker_procs.spawn("vllm_worker", PolicyWorker, **worker_dict) | ||
|
||
# TODO - expand support so name can stick within kwargs | ||
actor_name = kwargs.pop("name", cls.__name__) | ||
policy = await policy_proc.spawn( | ||
actor_name, | ||
cls, | ||
config=config, | ||
worker_params=worker_params, | ||
sampling_overrides=sampling_overrides, | ||
available_devices=available_devices, | ||
policy_worker=workers, | ||
store=store, | ||
) | ||
|
@@ -174,7 +196,7 @@ async def setup(self): | |
self.vllm_args = await self.policy_worker.get_vllm_args.choose() | ||
|
||
# Setup sampling params | ||
sampling_overrides = self.config.sampling_params | ||
sampling_overrides = self.sampling_overrides | ||
DNXie marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
overrides = { | ||
"n": sampling_overrides.num_samples, | ||
"guided_decoding": ( | ||
|
@@ -371,8 +393,6 @@ def __post_init__(self): | |
pipeline_parallel_size=self.pipeline_parallel_size, | ||
enforce_eager=self.enforce_eager, | ||
) | ||
# Original method returns False when not run in the main thread | ||
self.vllm_args._is_v1_supported_oracle = lambda *_: True | ||
else: | ||
# Check that provided args match Policy args | ||
cfg = [ | ||
|
@@ -388,6 +408,8 @@ def __post_init__(self): | |
f"{key} args don't match value in EngineArgs, overriding with {value}" | ||
) | ||
setattr(self.vllm_args, key, value) | ||
# Original method returns False when not run in the main thread | ||
self.vllm_args._is_v1_supported_oracle = lambda *_: True | ||
# Build Config | ||
self.vllm_args = self.vllm_args.create_engine_config(UsageContext.LLM_CLASS) | ||
|
||
|
Uh oh!
There was an error while loading. Please reload this page.