Add YAML-based configuration support for vLLM main #116
@@ -0,0 +1,19 @@
+policy_config:
+  worker_params:
+    model: "meta-llama/Llama-3.1-8B-Instruct"
+    tensor_parallel_size: 2
+    pipeline_parallel_size: 1
+    enforce_eager: true
+    vllm_args: null
+  sampling_params:
+    num_samples: 2
+    guided_decoding: false
+  available_devices: null
+
+service_config:
+  procs_per_replica: 2
+  num_replicas: 1
+  with_gpus: true
+
+# Optional, otherwise argparse fallback kicks in
+prompt: "Tell me a joke"
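Not part of the diff, but as a quick illustration of how a config file like the one above could be consumed: a minimal sketch assuming PyYAML is available and using the `PolicyConfig.from_dict` helper added later in this PR. The file name and the `PolicyConfig` import path are placeholders, not part of the change.

```python
# Sketch only: load the example YAML and build a PolicyConfig from it.
# Assumes PyYAML is installed; the file name and the import path below
# are placeholders, not part of this PR.
import yaml

from policy import PolicyConfig  # placeholder import path

with open("llama3_8b.yaml") as f:  # hypothetical config file name
    raw = yaml.safe_load(f)

policy_config = PolicyConfig.from_dict(raw["policy_config"])
print(policy_config.worker_params.model)           # meta-llama/Llama-3.1-8B-Instruct
print(policy_config.sampling_params.num_samples)   # 2

# "prompt" is optional in the YAML; in a real run the argparse fallback
# would supply it, here a literal default stands in.
prompt = raw.get("prompt", "Tell me a joke")
```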
@@ -61,7 +61,7 @@ class SamplingOverrides:
         guided_decoding: Whether to use guided decoding.
     """
 
-    num_samples: int
+    num_samples: int = 1
     guided_decoding: bool = False
     max_tokens: int = 512
 
@@ -79,19 +79,35 @@ class WorkerConfig:
         vllm_args: vLLM engine args.
     """
 
-    model: str
+    model: str = "meta-llama/Llama-3.1-8B-Instruct"
     tensor_parallel_size: int = 1
     pipeline_parallel_size: int = 1
     enforce_eager: bool = False
-    vllm_args: EngineArgs = None
+    vllm_args: EngineArgs = field(default_factory=EngineArgs)
 
+    @classmethod
+    def from_dict(cls, d: dict):
+        d = dict(d)  # copy
+        if "vllm_args" in d and isinstance(d["vllm_args"], dict):
+            d["vllm_args"] = EngineArgs(**d["vllm_args"])
+        return cls(**d)
+
 
 @dataclass
 class PolicyConfig:
-    worker_params: WorkerConfig
-    sampling_params: SamplingOverrides
+    worker_params: WorkerConfig = field(default_factory=WorkerConfig)
+    sampling_params: SamplingOverrides = field(default_factory=SamplingOverrides)
     available_devices: str = None
 
+    @classmethod
+    def from_dict(cls, d: dict):
+        d = dict(d)
+        if "worker_params" in d and isinstance(d["worker_params"], dict):
+            d["worker_params"] = WorkerConfig.from_dict(d["worker_params"])
+        if "sampling_params" in d and isinstance(d["sampling_params"], dict):
+            d["sampling_params"] = SamplingOverrides(**d["sampling_params"])
+        return cls(**d)
+
 
 @dataclass
 class Policy(PolicyInterface):
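For illustration only: the `from_dict` helpers above promote nested plain dicts to the corresponding dataclasses, so a parsed YAML mapping can be handed over directly. A small sketch under the assumptions that `PolicyConfig` and `WorkerConfig` are importable (placeholder path) and that `gpu_memory_utilization` is used merely as an example vLLM engine argument:

```python
# Sketch: nested dicts are promoted to dataclasses by the from_dict helpers.
from policy import PolicyConfig, WorkerConfig  # placeholder import path

cfg = PolicyConfig.from_dict(
    {
        "worker_params": {
            "model": "meta-llama/Llama-3.1-8B-Instruct",
            "tensor_parallel_size": 2,
            # A plain dict here is converted to EngineArgs by WorkerConfig.from_dict.
            "vllm_args": {"gpu_memory_utilization": 0.8},
        },
        "sampling_params": {"num_samples": 2, "guided_decoding": False},
    }
)
assert isinstance(cfg.worker_params, WorkerConfig)

# Omitted sections fall back to the new dataclass defaults, e.g. num_samples = 1.
defaults = PolicyConfig.from_dict({})
assert defaults.sampling_params.num_samples == 1
```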
@@ -108,6 +124,8 @@ def __post_init__(self):
         self._policy_proc: ProcMesh | None = None
         self._worker_procs: ProcMesh | None = None
         self.weights_version: int = 0
+        if isinstance(self.config, dict):
+            self.config = PolicyConfig.from_dict(self.config)
 
     @classmethod
     async def launch(  # pyright: ignore[reportIncompatibleMethodOverride]
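With the `__post_init__` change, `Policy` accepts either an already-built `PolicyConfig` or a raw dict (for example, straight from `yaml.safe_load`). The real `Policy` needs proc meshes to construct, so the stand-in dataclass below only demonstrates the coercion pattern; `FakePolicy` and the import path are hypothetical.

```python
# Stand-in demonstrating the normalization the PR adds to Policy.__post_init__:
# accept either a PolicyConfig instance or a plain dict and coerce the latter.
from dataclasses import dataclass, field

from policy import PolicyConfig  # placeholder import path


@dataclass
class FakePolicy:
    """Hypothetical stand-in; the real Policy also sets up proc meshes."""

    config: PolicyConfig | dict = field(default_factory=PolicyConfig)

    def __post_init__(self) -> None:
        if isinstance(self.config, dict):
            self.config = PolicyConfig.from_dict(self.config)


p = FakePolicy(config={"worker_params": {"tensor_parallel_size": 2}})
assert isinstance(p.config, PolicyConfig)
assert p.config.worker_params.tensor_parallel_size == 2
```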