Add YAML-based configuration support for vLLM main #116
Changes from 39 commits
apps/vllm/llama3_8b.yaml (new file)

@@ -0,0 +1,19 @@
+policy:
+  engine_params:
+    model: "meta-llama/Llama-3.1-8B-Instruct"
+    tensor_parallel_size: 2
+    pipeline_parallel_size: 1
+    enforce_eager: true
+  sampling_overrides:
+    n: 2
+    guided_decoding: false
+    max_tokens: 512
+  available_devices: null
+  service:
+    procs_per_replica: 2
+    num_replicas: 1
+    with_gpus: true
+
+
+# Optional, otherwise argparse fallback kicks in
+prompt: "Tell me a joke"
apps/vllm/main.py

@@ -6,95 +6,34 @@
 """To run:
 export HF_HUB_DISABLE_XET=1
-python -m apps.vllm.main --guided-decoding --num-samples 3
+python -m apps.vllm.main --config apps/vllm/llama3_8b.yaml
 """
 
-import argparse
 import asyncio
-from argparse import Namespace
+import sys
 
-from forge.actors.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig
+from forge.actors.policy import Policy
+from forge.cli.config import parse
 from forge.controller.service import ServiceConfig, shutdown_service, spawn_service
-from vllm.outputs import RequestOutput
-from vllm.transformers_utils.tokenizer import get_tokenizer
+
+from omegaconf import DictConfig
+from vllm.outputs import RequestOutput
 
 
-async def main():
-    """Main application for running vLLM policy inference."""
-    args = parse_args()
-
-    # Create configuration objects
-    policy_config, service_config = get_configs(args)
+async def run(cfg: DictConfig):
 
-    # Resolve the Prompts
-    if args.prompt is None:
-        prompt = "What is 3+5?" if args.guided_decoding else "Tell me a joke"
+    if "prompt" in cfg and cfg["prompt"] is not None:
+        prompt = cfg["prompt"]
     else:
-        prompt = args.prompt
-
-    # format prompt
-    tokenizer = get_tokenizer(policy_config.worker_params.model)
-    messages = [{"role": "user", "content": prompt}]
-    prompt = tokenizer.apply_chat_template(
-        messages, tokenize=False, add_generation_prompt=True
-    )
-
-    # Run the policy
-    await run_vllm(service_config, policy_config, prompt)
-
-
-def parse_args() -> Namespace:
-    parser = argparse.ArgumentParser(description="VLLM Policy Inference Application")
-    parser.add_argument(
-        "--model",
-        type=str,
-        default="Qwen/Qwen3-1.7B",  # "meta-llama/Llama-3.1-8B-Instruct",
-        help="Model to use",
-    )
-    parser.add_argument(
-        "--num-samples", type=int, default=2, help="Number of samples to generate"
-    )
-    parser.add_argument(
-        "--guided-decoding", action="store_true", help="Enable guided decoding"
-    )
-    parser.add_argument(
-        "--prompt", type=str, default=None, help="Custom prompt to use for generation"
-    )
-    return parser.parse_args()
-
-
-def get_configs(args: Namespace) -> (PolicyConfig, ServiceConfig):
-
-    worker_size = 2
-    worker_params = WorkerConfig(
-        model=args.model,
-        tensor_parallel_size=worker_size,
-        pipeline_parallel_size=1,
-        enforce_eager=True,
-        vllm_args=None,
-    )
-
-    sampling_params = SamplingOverrides(
-        n=args.num_samples,
-        guided_decoding=args.guided_decoding,
-        max_tokens=16,
-    )
-
-    policy_config = PolicyConfig(
-        worker_params=worker_params, sampling_params=sampling_params
-    )
-    service_config = ServiceConfig(
-        procs_per_replica=worker_size, num_replicas=1, with_gpus=True
-    )
-
-    return policy_config, service_config
-
-
-async def run_vllm(service_config: ServiceConfig, config: PolicyConfig, prompt: str):
+        gd = cfg.policy.get("sampling_overrides", {}).get("guided_decoding", False)
+        prompt = "What is 3+5?" if gd else "Tell me a joke"
+
     print("Spawning service...")
-    policy = await spawn_service(service_config, Policy, config=config)
+
+    policy = await spawn_service(
+        ServiceConfig(**cfg.policy.service), Policy, **cfg.policy
+    )
 
     async with policy.session():
         print("Requesting generation...")
         response_output: RequestOutput = await policy.generate.choose(prompt=prompt)

@@ -112,5 +51,10 @@ async def run_vllm(service_config: ServiceConfig, config: PolicyConfig, prompt: str):
     await shutdown_service(policy)
 
 
+@parse
+def recipe_main(cfg: DictConfig) -> None:
+    asyncio.run(run(cfg))
+
+
 if __name__ == "__main__":
-    asyncio.run(main())
+    sys.exit(recipe_main())
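The `@parse` decorator from `forge.cli.config` is not shown in this diff. As a rough mental model only (an assumption, not the actual implementation), it behaves like a thin wrapper that reads `--config`, loads the YAML into an OmegaConf `DictConfig`, and passes it to the decorated entry point:

```python
# Hypothetical sketch of a parse-style decorator; the real
# forge.cli.config.parse may differ. Only the --config flag usage is
# taken from the module docstring above.
import argparse
import functools

from omegaconf import OmegaConf


def parse(entry_point):
    @functools.wraps(entry_point)
    def wrapper() -> None:
        parser = argparse.ArgumentParser()
        parser.add_argument("--config", required=True, help="Path to a YAML config")
        # Treat any extra KEY=VALUE args as dotlist overrides (assumption).
        args, overrides = parser.parse_known_args()
        cfg = OmegaConf.merge(
            OmegaConf.load(args.config), OmegaConf.from_dotlist(overrides)
        )
        return entry_point(cfg)

    return wrapper
```

Under that assumption, something like `python -m apps.vllm.main --config apps/vllm/llama3_8b.yaml prompt="What is 3+5?"` would override the prompt from the command line; only the plain `--config` invocation is confirmed by the docstring.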
forge/actors/policy.py

@@ -8,6 +8,7 @@
 import logging
 import os
 import sys
+from collections.abc import Mapping
 from copy import copy
 from dataclasses import asdict, dataclass, field
 from typing import Dict, List

@@ -62,7 +63,7 @@ class SamplingOverrides:
         max_tokens: Maximum number of tokens to generate.
     """
 
-    n: int
+    n: int = 1
     guided_decoding: bool = False
     max_tokens: int = 512
 

@@ -72,37 +73,39 @@ def __post_init__(self):
             gd_params = GuidedDecodingParams(choice=["Positive", "Negative"])
             self.guided_decoding = gd_params
 
+    @classmethod
+    def from_dict(cls, d: Mapping):
+        d = dict(d)
+        all_fields = set(cls.__dataclass_fields__.keys())
+        valid_args = {k: v for k, v in d.items() if k in all_fields}
+        return cls(**valid_args)
+
 
 @dataclass
-class WorkerConfig:
+class EngineConfig(EngineArgs):
""" | ||||||||||
Config args used for setting up the policy worker. | ||||||||||
|
||||||||||
Args: | ||||||||||
model: Model name. | ||||||||||
tensor_parallel_size: Number of tensor parallel workers. | ||||||||||
pipeline_parallel_size: Number of pipeline parallel workers. | ||||||||||
enforce_eager: Whether to enforce eager mode. | ||||||||||
vllm_args: vLLM engine args. | ||||||||||
EngineConfig extends EngineArgs with worker-specific fields. | ||||||||||
Overlapping keys in input dict will override EngineArgs defaults. | ||||||||||
|
Suggested change (from review):
-    EngineConfig extends EngineArgs with worker-specific fields.
-    Overlapping keys in input dict will override EngineArgs defaults.
+    EngineConfig extends EngineArgs surfacing worker-specific fields.
+    Args of this class override EngineArgs defaults.
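Separately from the docstring wording, here is a small usage sketch of the new `SamplingOverrides.from_dict` helper. The call site is made up for illustration; how `Policy` actually consumes it is not part of this hunk:

```python
# Illustrative only: from_dict keeps keys that match the dataclass fields
# and silently drops anything else, so extra YAML keys do not break it.
from forge.actors.policy import SamplingOverrides

raw = {"n": 2, "guided_decoding": False, "max_tokens": 512, "unknown_key": 1}
overrides = SamplingOverrides.from_dict(raw)
print(overrides)  # SamplingOverrides(n=2, guided_decoding=False, max_tokens=512)
```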
Review comment: I read it here that tensor_parallel_size is under EngineConfig.parallel_config.tensor_parallel_size. If so, is this implementation correct? Should the user pass the value like this instead?

    policy:
      engine_params:
        parallel_config:
          tensor_parallel_size: 1
Reply: I comment on this below; what we have is fine, since parallel_config doesn't actually exist until create_engine_config is called.
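To make that point concrete, a small sketch of the vLLM behaviour being described (assuming the standard `EngineArgs` API; running it needs the model config to be reachable, e.g. from the Hugging Face hub):

```python
# tensor_parallel_size is a flat field on EngineArgs; a parallel_config
# object only exists on the config produced by create_engine_config(),
# which is why the YAML keeps it directly under engine_params.
from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(
    model="meta-llama/Llama-3.1-8B-Instruct",
    tensor_parallel_size=2,
)
vllm_config = engine_args.create_engine_config()
print(vllm_config.parallel_config.tensor_parallel_size)  # 2
```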
Review comment: Are we allowing Mapping as an input type just to work around the YAML?
Reply: Yes. Any suggestions?
Reply: No need to change anything here, but it's worth thinking about down the line whether we should shim this out across the repo (an abstraction that handles all the class construction, so actors can act on pure Python) to keep the actor logic simpler.
Reply: Sounds reasonable. I agree. Let's not include it in this PR for now.
Review comment: There is a typing/design quirk throughout the repo that I'm not a huge fan of, but it's out of scope for this PR: the type used during construction isn't necessarily the same type the field holds after construction. In this case vllm_args can be passed as a dict | EngineArgs, but a dict gets converted to EngineArgs (to make calling easier from the YAML?), and in __post_init__ it is further transformed into a VllmConfig.
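A generic sketch of the pattern this comment describes (not forge's actual code; the class and field names here are made up): the constructor accepts a YAML-friendly dict, but the stored field ends up as a richer object after `__post_init__`:

```python
# Generic illustration of "input type != field type after construction".
from dataclasses import dataclass

from vllm.engine.arg_utils import EngineArgs


@dataclass
class WorkerSetup:
    # Accepts a plain dict (easy to write in YAML) or an EngineArgs instance.
    vllm_args: dict | EngineArgs | None = None

    def __post_init__(self):
        if isinstance(self.vllm_args, dict):
            self.vllm_args = EngineArgs(**self.vllm_args)
        # The real actor goes one step further and turns this into a
        # vLLM engine config via create_engine_config(); omitted here.
```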