Add YAML-based configuration support for vLLM main #116
apps/vllm/llama3_8b.yaml (new file)

@@ -0,0 +1,19 @@
```yaml
policy:
  engine_config:
    model: "meta-llama/Llama-3.1-8B-Instruct"
    tensor_parallel_size: 2
    pipeline_parallel_size: 1
    enforce_eager: true
  sampling_config:
    n: 2
    guided_decoding: false
    max_tokens: 512
  available_devices: null
  service:
    procs_per_replica: 2
    num_replicas: 1
    with_gpus: true

# Optional, otherwise argparse fallback kicks in
prompt: "Tell me a joke"
```
apps/vllm/main.py (the removed argparse-driven lines and the added YAML-driven lines appear interleaved in these hunks)

@@ -6,95 +6,32 @@
```python
"""To run:
export HF_HUB_DISABLE_XET=1
python -m apps.vllm.main --guided-decoding --num-samples 3

python -m apps.vllm.main --config apps/vllm/llama3_8b.yaml
"""

import argparse
import asyncio
from argparse import Namespace
import sys

from forge.actors.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig
from forge.actors.policy import Policy
from forge.cli.config import parse
from forge.controller.service import ServiceConfig, shutdown_service, spawn_service
from vllm.outputs import RequestOutput
from vllm.transformers_utils.tokenizer import get_tokenizer


async def main():
    """Main application for running vLLM policy inference."""
    args = parse_args()

    # Create configuration objects
    policy_config, service_config = get_configs(args)

    # Resolve the Prompts
    if args.prompt is None:
        prompt = "What is 3+5?" if args.guided_decoding else "Tell me a joke"
    else:
        prompt = args.prompt

    # format prompt
    tokenizer = get_tokenizer(policy_config.worker_params.model)
    messages = [{"role": "user", "content": prompt}]
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # Run the policy
    await run_vllm(service_config, policy_config, prompt)


def parse_args() -> Namespace:
    parser = argparse.ArgumentParser(description="VLLM Policy Inference Application")
    parser.add_argument(
        "--model",
        type=str,
        default="Qwen/Qwen3-1.7B",  # "meta-llama/Llama-3.1-8B-Instruct",
        help="Model to use",
    )
    parser.add_argument(
        "--num-samples", type=int, default=2, help="Number of samples to generate"
    )
    parser.add_argument(
        "--guided-decoding", action="store_true", help="Enable guided decoding"
    )
    parser.add_argument(
        "--prompt", type=str, default=None, help="Custom prompt to use for generation"
    )
    return parser.parse_args()
from omegaconf import DictConfig
from vllm.outputs import RequestOutput


def get_configs(args: Namespace) -> (PolicyConfig, ServiceConfig):
async def run(cfg: DictConfig):

    worker_size = 2
    worker_params = WorkerConfig(
        model=args.model,
        tensor_parallel_size=worker_size,
        pipeline_parallel_size=1,
        enforce_eager=True,
        vllm_args=None,
    )
    if (prompt := cfg.get("prompt")) is None:
        gd = cfg.policy.get("sampling_config", {}).get("guided_decoding", False)
        prompt = "What is 3+5?" if gd else "Tell me a joke"

    sampling_params = SamplingOverrides(
        n=args.num_samples,
        guided_decoding=args.guided_decoding,
        max_tokens=16,
    )
    print("Spawning service...")

    policy_config = PolicyConfig(
        worker_params=worker_params, sampling_params=sampling_params
    )
    service_config = ServiceConfig(
        procs_per_replica=worker_size, num_replicas=1, with_gpus=True
    policy = await spawn_service(
        ServiceConfig(**cfg.policy.service), Policy, **cfg.policy
    )

    return policy_config, service_config


async def run_vllm(service_config: ServiceConfig, config: PolicyConfig, prompt: str):
    print("Spawning service...")
    policy = await spawn_service(service_config, Policy, config=config)

    async with policy.session():
        print("Requesting generation...")
        response_output: RequestOutput = await policy.generate.choose(prompt=prompt)
```

@@ -112,5 +49,10 @@ async def run_vllm(service_config: ServiceConfig, config: PolicyConfig, prompt:
```python
    await shutdown_service(policy)


@parse
def recipe_main(cfg: DictConfig) -> None:
    asyncio.run(run(cfg))


if __name__ == "__main__":
    asyncio.run(main())
    sys.exit(recipe_main())
```
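The new entrypoint relies on forge.cli.config.parse to turn --config (plus any CLI overrides) into a DictConfig. As a rough illustration only, a stripped-down decorator in that spirit might look like the sketch below; the flag handling, override syntax, and return value are assumptions, not forge's actual implementation.

```python
# Hypothetical sketch of a `parse`-style decorator (NOT the real
# forge.cli.config.parse): load --config into a DictConfig and merge any
# dotlist overrides from the command line before calling the wrapped recipe.
import argparse
import functools

from omegaconf import DictConfig, OmegaConf


def parse(recipe):
    @functools.wraps(recipe)
    def wrapper() -> int:
        parser = argparse.ArgumentParser()
        parser.add_argument("--config", required=True, help="Path to a YAML config")
        args, overrides = parser.parse_known_args()

        cfg: DictConfig = OmegaConf.load(args.config)
        # e.g. `policy.sampling_config.n=4` on the command line
        cfg = OmegaConf.merge(cfg, OmegaConf.from_dotlist(overrides))

        recipe(cfg)
        return 0

    return wrapper
```

Under this shape, recipe_main() returns an exit code, which is why the entrypoint above can call sys.exit(recipe_main()).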
forge/actors/policy.py

@@ -8,8 +8,9 @@
```python
import logging
import os
import sys
from collections.abc import Mapping
from copy import copy
from dataclasses import asdict, dataclass, field
from dataclasses import asdict, dataclass, field, fields
from typing import Dict, List

import torch
```

@@ -48,7 +49,7 @@
```python

@dataclass
class SamplingOverrides:
class SamplingConfig:
    """
    Overrides for vLLMs sampling params.
```

@@ -62,7 +63,7 @@ class SamplingOverrides:
```python
        max_tokens: Maximum number of tokens to generate.
    """

    n: int
    n: int = 1
    guided_decoding: bool = False
    max_tokens: int = 512
```

@@ -72,37 +73,39 @@ def __post_init__(self):
```python
            gd_params = GuidedDecodingParams(choice=["Positive", "Negative"])
            self.guided_decoding = gd_params

    @classmethod
    def from_dict(cls, d: Mapping):
        d = dict(d)
        all_fields = set(cls.__dataclass_fields__.keys())
        valid_args = {k: v for k, v in d.items() if k in all_fields}
        return cls(**valid_args)


@dataclass
class WorkerConfig:
class EngineConfig(EngineArgs):
    """
    Config args used for setting up the policy worker.

    Args:
        model: Model name.
        tensor_parallel_size: Number of tensor parallel workers.
        pipeline_parallel_size: Number of pipeline parallel workers.
        enforce_eager: Whether to enforce eager mode.
        vllm_args: vLLM engine args.
    EngineConfig extends EngineArgs with worker-specific fields.
    Overlapping keys in input dict will override EngineArgs defaults.
```
A reviewer suggested rewording the two new docstring lines:

```
- EngineConfig extends EngineArgs with worker-specific fields.
- Overlapping keys in input dict will override EngineArgs defaults.
+ EngineConfig extends EngineArgs surfacing worker-specific fields.
+ Args of this class override EngineArgs defaults.
```
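For clarity, the from_dict classmethod introduced in the diff above filters a mapping down to the declared dataclass fields before construction, so unknown keys are ignored instead of raising a TypeError. A self-contained illustration of the same pattern, using a stand-in dataclass rather than the real forge SamplingConfig:

```python
# Illustration of the from_dict pattern from the diff above: unknown keys
# in the mapping are dropped against the declared dataclass fields before
# construction. `DemoSamplingConfig` is a stand-in, not forge's class.
from collections.abc import Mapping
from dataclasses import dataclass


@dataclass
class DemoSamplingConfig:
    n: int = 1
    guided_decoding: bool = False
    max_tokens: int = 512

    @classmethod
    def from_dict(cls, d: Mapping):
        d = dict(d)
        valid = {k: v for k, v in d.items() if k in cls.__dataclass_fields__}
        return cls(**valid)


# "temperature" is silently ignored because it is not a declared field
cfg = DemoSamplingConfig.from_dict({"n": 4, "temperature": 0.7})
assert cfg == DemoSamplingConfig(n=4)
```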
I read here that tensor_parallel_size is under EngineConfig.parallel_config.tensor_parallel_size. If so, is this implementation correct? Should the user pass the value like this instead:

```yaml
policy:
  engine_params:
    parallel_config:
      tensor_parallel_size: 1
```
I comment on this below; what we have is fine, since parallel_config doesn't actually exist until create_engine_config is called.
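To illustrate that reply: tensor_parallel_size is a flat field on vLLM's EngineArgs, and the nested parallel_config only materializes once create_engine_config() is called. A hedged sketch follows; the exact return type of create_engine_config varies across vLLM versions, so treat the attribute shown as representative rather than definitive.

```python
# Sketch: tensor_parallel_size is a flat EngineArgs field (matching the
# flat YAML layout above); a nested parallel_config only exists after
# create_engine_config() is called. Assumes vLLM's EngineArgs API; the
# return type of create_engine_config differs across vLLM versions.
from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(
    model="meta-llama/Llama-3.1-8B-Instruct",
    tensor_parallel_size=2,  # flat, as in the YAML config
    enforce_eager=True,
)

# Only at this point does a parallel_config with tensor_parallel_size appear
vllm_config = engine_args.create_engine_config()
print(vllm_config.parallel_config.tensor_parallel_size)  # 2
```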
Out of scope for this PR, but we should think about the "service in yaml" pattern when we have some breathing room. We're gonna have a pattern of excluding this field when passing args around (since X.service is not a common Agent Arg).
Could you be more clear with the suggestions?
I personally think we should have spawn_service handle this to make it less awkward, but we can do that later. Something like spawn_service(actor: Actor, service_config: ServiceConfig | Mapping, **kwargs).
Agreed, no action required here
Seconding the API too
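A sketch of the API being seconded here, purely hypothetical and not the current forge signature: spawn_service (or a thin wrapper over it) could accept either a ServiceConfig or a raw mapping such as cfg.policy.service and coerce it internally, so callers never unpack the YAML section by hand.

```python
# Hypothetical sketch of the suggested API (NOT the current forge
# signature): accept either a ServiceConfig or a raw mapping such as
# cfg.policy.service, coercing the mapping before delegating to the
# existing spawn_service call order used in main.py.
from collections.abc import Mapping

from forge.controller.service import ServiceConfig, spawn_service


async def spawn_service_flexible(actor, service_config, **kwargs):
    if isinstance(service_config, Mapping):
        # e.g. the `policy.service` section of the YAML config
        service_config = ServiceConfig(**service_config)
    # Delegate using the existing (service_config, actor, **kwargs) order
    return await spawn_service(service_config, actor, **kwargs)
```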