Add YAML-based configuration support for vLLM main #116
Changes from 45 commits
apps/vllm/llama3_8b.yaml (new file)
@@ -0,0 +1,19 @@
policy:
  engine_params:
    model: "meta-llama/Llama-3.1-8B-Instruct"
    tensor_parallel_size: 2
    pipeline_parallel_size: 1
    enforce_eager: true
  sampling_overrides:
    n: 2
    guided_decoding: false
    max_tokens: 512
  available_devices: null
  service:
    procs_per_replica: 2
    num_replicas: 1
    with_gpus: true

# Optional, otherwise argparse fallback kicks in
prompt: "Tell me a joke"
apps/vllm/main.py
@@ -6,95 +6,36 @@

"""To run:
export HF_HUB_DISABLE_XET=1
python -m apps.vllm.main --guided-decoding --num-samples 3

python -m apps.vllm.main --config apps/vllm/llama3_8b.yaml
"""

import argparse
import asyncio
from argparse import Namespace
import sys

from forge.actors.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig
from forge.actors.policy import Policy
from forge.cli.config import parse
from forge.controller.service import ServiceConfig, shutdown_service, spawn_service
from vllm.outputs import RequestOutput
from vllm.transformers_utils.tokenizer import get_tokenizer

from omegaconf import DictConfig
from vllm.outputs import RequestOutput


async def main():
    """Main application for running vLLM policy inference."""
    args = parse_args()

    # Create configuration objects
    policy_config, service_config = get_configs(args)
async def run(cfg: DictConfig):

    # Resolve the Prompts
    if args.prompt is None:
        prompt = "What is 3+5?" if args.guided_decoding else "Tell me a joke"
    if (prompt := cfg.get("prompt")) is None:
        gd = cfg.policy.get("sampling_overrides", {}).get("guided_decoding", False)
        prompt = "What is 3+5?" if gd else "Tell me a joke"
        prompt = cfg["prompt"]
    else:
        prompt = args.prompt

    # format prompt
    tokenizer = get_tokenizer(policy_config.worker_params.model)
    messages = [{"role": "user", "content": prompt}]
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # Run the policy
    await run_vllm(service_config, policy_config, prompt)


def parse_args() -> Namespace:
    parser = argparse.ArgumentParser(description="VLLM Policy Inference Application")
    parser.add_argument(
        "--model",
        type=str,
        default="Qwen/Qwen3-1.7B",  # "meta-llama/Llama-3.1-8B-Instruct",
        help="Model to use",
    )
    parser.add_argument(
        "--num-samples", type=int, default=2, help="Number of samples to generate"
    )
    parser.add_argument(
        "--guided-decoding", action="store_true", help="Enable guided decoding"
    )
    parser.add_argument(
        "--prompt", type=str, default=None, help="Custom prompt to use for generation"
    )
    return parser.parse_args()

    gd = cfg.policy.get("sampling_overrides", {}).get("guided_decoding", False)
    prompt = "What is 3+5?" if gd else "Tell me a joke"


def get_configs(args: Namespace) -> (PolicyConfig, ServiceConfig):

    worker_size = 2
    worker_params = WorkerConfig(
        model=args.model,
        tensor_parallel_size=worker_size,
        pipeline_parallel_size=1,
        enforce_eager=True,
        vllm_args=None,
    )

    sampling_params = SamplingOverrides(
        n=args.num_samples,
        guided_decoding=args.guided_decoding,
        max_tokens=16,
    )
    print("Spawning service...")

    policy_config = PolicyConfig(
        worker_params=worker_params, sampling_params=sampling_params
    )
    service_config = ServiceConfig(
        procs_per_replica=worker_size, num_replicas=1, with_gpus=True
    policy = await spawn_service(
        ServiceConfig(**cfg.policy.service), Policy, **cfg.policy
    )

    return policy_config, service_config


async def run_vllm(service_config: ServiceConfig, config: PolicyConfig, prompt: str):
    print("Spawning service...")
    policy = await spawn_service(service_config, Policy, config=config)

    async with policy.session():
        print("Requesting generation...")
        response_output: RequestOutput = await policy.generate.choose(prompt=prompt)

@@ -112,5 +53,10 @@ async def run_vllm(service_config: ServiceConfig, config: PolicyConfig, prompt:
        await shutdown_service(policy)


@parse
def recipe_main(cfg: DictConfig) -> None:
    asyncio.run(run(cfg))


if __name__ == "__main__":
    asyncio.run(main())
    sys.exit(recipe_main())
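The new entry point delegates argument handling to the forge.cli.config.parse decorator, which hands run() a DictConfig built from the --config YAML. As a rough, hypothetical illustration of that pattern only (the real parse decorator may differ in behavior and likely supports more, e.g. command-line overrides):

```python
# Hypothetical sketch of a --config entry-point decorator; the real
# forge.cli.config.parse implementation may differ.
import argparse

from omegaconf import DictConfig, OmegaConf


def parse_sketch(func):
    """Wrap a recipe entry point so it receives a DictConfig from --config."""

    def wrapper() -> int:
        parser = argparse.ArgumentParser()
        parser.add_argument("--config", required=True, help="Path to a YAML config")
        args = parser.parse_args()
        cfg: DictConfig = OmegaConf.load(args.config)
        func(cfg)
        return 0

    return wrapper
```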
forge/actors/policy.py
@@ -8,6 +8,7 @@
import logging
import os
import sys
from collections.abc import Mapping
from copy import copy
from dataclasses import asdict, dataclass, field
from typing import Dict, List
@@ -62,7 +63,7 @@ class SamplingOverrides:
        max_tokens: Maximum number of tokens to generate.
    """

    n: int
    n: int = 1
    guided_decoding: bool = False
    max_tokens: int = 512
@@ -72,37 +73,43 @@ def __post_init__(self):
            gd_params = GuidedDecodingParams(choice=["Positive", "Negative"])
            self.guided_decoding = gd_params

    @classmethod
    def from_dict(cls, d: Mapping):
        d = dict(d)
        all_fields = set(cls.__dataclass_fields__.keys())
        valid_args = {k: v for k, v in d.items() if k in all_fields}
        return cls(**valid_args)


@dataclass
class WorkerConfig:
class EngineArgOverrides(EngineArgs):
    """
    Config args used for setting up the policy worker.

    Args:
        model: Model name.
        tensor_parallel_size: Number of tensor parallel workers.
        pipeline_parallel_size: Number of pipeline parallel workers.
        enforce_eager: Whether to enforce eager mode.
        vllm_args: vLLM engine args.
    EngineArgOverrides extends EngineArgs with worker-specific fields.
    Overlapping keys in input dict will override EngineArgs defaults.
    """

    model: str
    model: str = "meta-llama/Llama-3.1-8B-Instruct"
    tensor_parallel_size: int = 1
    pipeline_parallel_size: int = 1
    enforce_eager: bool = False
    vllm_args: EngineArgs = None


@dataclass
class PolicyConfig:
    worker_params: WorkerConfig
    sampling_params: SamplingOverrides
    available_devices: str = None
    @classmethod
    def from_dict(cls, d: Mapping):
        d = dict(d)
        all_fields = set(cls.__dataclass_fields__.keys())
        valid_args = {k: v for k, v in d.items() if k in all_fields}
        return cls(**valid_args)


@dataclass
class Policy(PolicyInterface):
    config: PolicyConfig
    engine_params: EngineArgOverrides | Mapping = field(
        default_factory=EngineArgOverrides
    )
    sampling_overrides: SamplingOverrides | Mapping = field(
        default_factory=SamplingOverrides
    )
    available_devices: str | None = None
    # Gets set up by setup
    sampling_params: SamplingParams | None = None
    lora_request: LoRARequest | None = None
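The two from_dict helpers implement the same pattern: build the dataclass from a mapping while silently dropping keys that are not dataclass fields, so a YAML sub-tree with extra keys does not raise a TypeError. A standalone sketch of the idea, using a made-up ToyConfig rather than the real EngineArgOverrides or SamplingOverrides:

```python
# Standalone illustration of the from_dict filtering pattern; ToyConfig is a
# hypothetical stand-in, not a class from this PR.
from collections.abc import Mapping
from dataclasses import dataclass, fields


@dataclass
class ToyConfig:
    model: str = "meta-llama/Llama-3.1-8B-Instruct"
    tensor_parallel_size: int = 1

    @classmethod
    def from_dict(cls, d: Mapping) -> "ToyConfig":
        # Keep only keys that are actual dataclass fields.
        valid = {f.name for f in fields(cls)}
        return cls(**{k: v for k, v in dict(d).items() if k in valid})


cfg = ToyConfig.from_dict({"model": "Qwen/Qwen3-1.7B", "unknown_key": 123})
print(cfg)  # ToyConfig(model='Qwen/Qwen3-1.7B', tensor_parallel_size=1)
```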
@@ -115,13 +122,21 @@ def __post_init__(self):
        self._policy_proc: ProcMesh | None = None
        self._worker_procs: ProcMesh | None = None
        self.weights_version: int = 0
        if isinstance(self.engine_params, Mapping):
            self.engine_params = EngineArgOverrides.from_dict(self.engine_params)
        if isinstance(self.sampling_overrides, Mapping):
            self.sampling_overrides = SamplingOverrides.from_dict(
                self.sampling_overrides
            )

    @classmethod
    async def launch(  # pyright: ignore[reportIncompatibleMethodOverride]
        cls: type["Policy"],
        *,
        process_config: ProcessConfig,
        config: PolicyConfig,
        engine_params: EngineArgOverrides | Mapping = EngineArgOverrides(),
        sampling_overrides: SamplingOverrides | Mapping = SamplingOverrides(),
        available_devices: str | None = None,
        store: MultiProcessStore | None = None,
        **kwargs,
    ) -> "Policy":
@@ -135,16 +150,25 @@ async def launch(  # pyright: ignore[reportIncompatibleMethodOverride]
            policy_proc_config.with_gpus = False

        policy_proc = await get_proc_mesh(process_config=policy_proc_config)

        if isinstance(engine_params, Mapping):
            engine_params = EngineArgOverrides.from_dict(engine_params)

        if isinstance(sampling_overrides, Mapping):
            sampling_overrides = SamplingOverrides(**sampling_overrides)

        workers = await worker_procs.spawn(
            "vllm_worker", PolicyWorker, **asdict(config.worker_params)
            "vllm_worker", PolicyWorker, vllm_args=engine_params
        )

        # TODO - expand support so name can stick within kwargs
        actor_name = kwargs.pop("name", cls.__name__)
        policy = await policy_proc.spawn(
            actor_name,
            cls,
            config=config,
            engine_params=engine_params,
            sampling_overrides=sampling_overrides,
            available_devices=available_devices,
            policy_worker=workers,
            store=store,
        )
@@ -182,7 +206,7 @@ async def setup(self):

        # Setup sampling params
        self.sampling_params = get_default_sampling_params(
            self.vllm_args, overrides=asdict(self.config.sampling_params)
            self.vllm_args, overrides=asdict(self.sampling_overrides)
        )

        # Setup processors
@@ -348,11 +372,7 @@ async def stop(self):

@dataclass
class PolicyWorker(ForgeActor):
    model: str
    tensor_parallel_size: int = 1
    pipeline_parallel_size: int = 1
    enforce_eager: bool = False
    vllm_args: EngineArgs = None
    vllm_args: EngineArgOverrides | Mapping = EngineArgOverrides()
    state_dict_key: str = "model_state_dict"

    def __post_init__(self):
@@ -368,31 +388,11 @@ def __post_init__(self):
        - all LLM generate methods, verify against LLM inputs
        - all executor methods verify no changes
        """
        if self.vllm_args is None:
            # Use default vllm EngineArgs
            self.vllm_args = EngineArgs(
                model=self.model,
                tensor_parallel_size=self.tensor_parallel_size,
                pipeline_parallel_size=self.pipeline_parallel_size,
                enforce_eager=self.enforce_eager,
            )
            # Original method returns False when not run in the main thread
            self.vllm_args._is_v1_supported_oracle = lambda *_: True
        else:
            # Check that provided args match Policy args
            cfg = [
                "model",
                "tensor_parallel_size",
                "pipeline_parallel_size",
                "data_parallel_size",
            ]
            for key in cfg:
                value = getattr(self, key) if key != "data_parallel_size" else 1
                if getattr(self.vllm_args, key) != value:
                    logger.warning(
                        f"{key} args don't match value in EngineArgs, overriding with {value}"
                    )
                    setattr(self.vllm_args, key, value)
        if isinstance(self.vllm_args, Mapping):
            self.vllm_args = EngineArgOverrides.from_dict(self.vllm_args)

        # Original method returns False when not run in the main thread
        self.vllm_args._is_v1_supported_oracle = lambda *_: True
        # Build Config
        self.vllm_args = self.vllm_args.create_engine_config(UsageContext.LLM_CLASS)
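For context on the last line of this hunk: create_engine_config is the standard vLLM call that turns an EngineArgs-style object into the full engine configuration. A small sketch against stock vLLM (plain EngineArgs rather than the PR's EngineArgOverrides subclass; assumes a recent vLLM where UsageContext lives in vllm.usage.usage_lib, and note it will load the model config when run):

```python
# Hedged sketch: plain vLLM EngineArgs -> engine config, mirroring what
# PolicyWorker.__post_init__ does with EngineArgOverrides.
from vllm.engine.arg_utils import EngineArgs
from vllm.usage.usage_lib import UsageContext

args = EngineArgs(
    model="meta-llama/Llama-3.1-8B-Instruct",
    tensor_parallel_size=2,
    enforce_eager=True,
)
vllm_config = args.create_engine_config(UsageContext.LLM_CLASS)
# The resulting config bundles model/parallel/scheduler sub-configs, which is
# where the sharding code below reads tensor_parallel_size from.
print(vllm_config.parallel_config.tensor_parallel_size)  # 2
```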
@@ -416,7 +416,9 @@ async def _load_tensor_parallel_state_dict(

        updated_count = 0
        # setting explictly to llama3 for now as its our only use case
        sharding = VLLMSharding(self.tensor_parallel_size, self.rank)
        sharding = VLLMSharding(
            self.vllm_args.parallel_config.tensor_parallel_size, self.rank
        )

        for param_name in current_state_dict.keys():
            current_tensor = current_state_dict[param_name]
Review discussion:

Out of scope for this PR, but we should think about the "service in yaml" pattern when we have some breathing room. We're going to have a pattern of excluding this field when passing args around (since X.service is not a common Agent arg).

Could you be more clear with the suggestions?

I personally think we should have spawn_service handle this to make it less awkward, but we can do that later. Something like:

spawn_service(actor: Actor, service_config: ServiceConfig | Mapping, **kwargs)

Agreed, no action required here. Seconding the API too.
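A rough sketch of what that proposed API could look like: accept either a ServiceConfig or a raw mapping (such as cfg.policy.service), and drop the nested service block before forwarding the remaining keys as actor args. The helper below is hypothetical, not the actual forge.controller.service implementation:

```python
# Hypothetical wrapper illustrating the spawn_service signature floated in review;
# not the real forge implementation.
from collections.abc import Mapping
from typing import Any

from forge.controller.service import ServiceConfig, spawn_service


async def spawn_service_from_cfg(
    actor: type, service_config: ServiceConfig | Mapping, **kwargs: Any
):
    # Accept either a ServiceConfig or a raw mapping (e.g. cfg.policy.service).
    if isinstance(service_config, Mapping):
        service_config = ServiceConfig(**service_config)
    # Don't forward the service sub-tree as an actor argument.
    kwargs.pop("service", None)
    return await spawn_service(service_config, actor, **kwargs)
```

With something like this, the call in main.py could shrink to await spawn_service_from_cfg(Policy, cfg.policy.service, **cfg.policy).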