
Commit 7301e10

Add explicit from_dict methods for PolicyConfig and WorkerConfig
1 parent d8d775a commit 7301e10

File tree

3 files changed (+77, -73 lines)


apps/vllm/config.yaml

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
+policy_config:
+  worker_params:
+    model: "meta-llama/Llama-3.1-8B-Instruct"
+    tensor_parallel_size: 2
+    pipeline_parallel_size: 1
+    enforce_eager: true
+    vllm_args: null
+  sampling_params:
+    num_samples: 2
+    guided_decoding: false
+  available_devices: null
+
+service_config:
+  procs_per_replica: 2
+  num_replicas: 1
+  with_gpus: true
+
+# Optional, otherwise argparse fallback kicks in
+prompt: "Tell me a joke"
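
For orientation, the following is a minimal sketch of how this YAML is consumed, mirroring the load_yaml_config and get_configs helpers added in apps/vllm/main.py below; it assumes forge and its vLLM dependencies are importable.

# Sketch only: mirrors load_yaml_config/get_configs from apps/vllm/main.py in this commit.
import yaml

from forge.actors.policy import PolicyConfig
from forge.controller.service import ServiceConfig

with open("apps/vllm/config.yaml", "r") as f:
    cfg = yaml.safe_load(f)

# Nested dicts are turned into dataclasses by the new from_dict helper.
policy_config = PolicyConfig.from_dict(cfg["policy_config"])
service_config = ServiceConfig(**cfg["service_config"])

print(policy_config.worker_params.model)          # "meta-llama/Llama-3.1-8B-Instruct"
print(policy_config.sampling_params.num_samples)  # 2
print(cfg.get("prompt"))                          # "Tell me a joke"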

apps/vllm/main.py

Lines changed: 35 additions & 68 deletions
@@ -6,81 +6,34 @@
 
 """To run:
 export HF_HUB_DISABLE_XET=1
-python -m apps.vllm.main --guided-decoding --num-samples 3
-
+python -m apps.vllm.main --config apps/vllm/config.yaml
 """
 
-import argparse
 import asyncio
-from argparse import Namespace
-from typing import List
+import sys
+from typing import Any
+
+import yaml
 
-from forge.actors.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig
+from forge.actors.policy import Policy, PolicyConfig
 from forge.controller.service import ServiceConfig, shutdown_service, spawn_service
-from vllm.outputs import CompletionOutput
 
 
-async def main():
-    """Main application for running vLLM policy inference."""
-    args = parse_args()
+def load_yaml_config(path: str) -> dict:
+    with open(path, "r") as f:
+        return yaml.safe_load(f)
 
-    # Create configuration objects
-    policy_config, service_config = get_configs(args)
 
-    # Resolve the Prompts
-    if args.prompt is None:
-        prompt = "What is 3+5?" if args.guided_decoding else "Tell me a joke"
+def get_configs(cfg: dict) -> tuple[PolicyConfig, ServiceConfig, str]:
+    # Instantiate PolicyConfig and ServiceConfig from nested dicts
+    policy_config = PolicyConfig.from_dict(cfg["policy_config"])
+    service_config = ServiceConfig(**cfg["service_config"])
+    if "prompt" in cfg and cfg["prompt"] is not None:
+        prompt = cfg["prompt"]
     else:
-        prompt = args.prompt
-
-    # Run the policy
-    await run_vllm(service_config, policy_config, prompt)
-
-
-def parse_args() -> Namespace:
-    parser = argparse.ArgumentParser(description="VLLM Policy Inference Application")
-    parser.add_argument(
-        "--model",
-        type=str,
-        default="meta-llama/Llama-3.1-8B-Instruct",
-        help="Model to use",
-    )
-    parser.add_argument(
-        "--num-samples", type=int, default=2, help="Number of samples to generate"
-    )
-    parser.add_argument(
-        "--guided-decoding", action="store_true", help="Enable guided decoding"
-    )
-    parser.add_argument(
-        "--prompt", type=str, default=None, help="Custom prompt to use for generation"
-    )
-    return parser.parse_args()
-
-
-def get_configs(args: Namespace) -> (PolicyConfig, ServiceConfig):
-
-    worker_size = 2
-    worker_params = WorkerConfig(
-        model=args.model,
-        tensor_parallel_size=worker_size,
-        pipeline_parallel_size=1,
-        enforce_eager=True,
-        vllm_args=None,
-    )
-
-    sampling_params = SamplingOverrides(
-        num_samples=args.num_samples,
-        guided_decoding=args.guided_decoding,
-    )
-
-    policy_config = PolicyConfig(
-        worker_params=worker_params, sampling_params=sampling_params
-    )
-    service_config = ServiceConfig(
-        procs_per_replica=worker_size, num_replicas=1, with_gpus=True
-    )
-
-    return policy_config, service_config
+        gd = policy_config.sampling_params.guided_decoding
+        prompt = "What is 3+5?" if gd else "Tell me a joke"
+    return policy_config, service_config, prompt
 
 
 async def run_vllm(service_config: ServiceConfig, config: PolicyConfig, prompt: str):
@@ -89,11 +42,11 @@ async def run_vllm(service_config: ServiceConfig, config: PolicyConfig, prompt: str):
 
     async with policy.session():
         print("Requesting generation...")
-        responses: List[CompletionOutput] = await policy.generate.choose(prompt=prompt)
+        response_output = await policy.generate.choose(prompt=prompt)
 
         print("\nGeneration Results:")
         print("=" * 80)
-        for batch, response in enumerate(responses):
+        for batch, response in enumerate(response_output.outputs):
            print(f"Sample {batch + 1}:")
            print(f"User: {prompt}")
            print(f"Assistant: {response.text}")
@@ -104,5 +57,19 @@ async def run_vllm(service_config: ServiceConfig, config: PolicyConfig, prompt: str):
     await shutdown_service(policy)
 
 
+def main():
+    import argparse
+
+    parser = argparse.ArgumentParser(description="vLLM Policy Inference Application")
+    parser.add_argument(
+        "--config", type=str, required=True, help="Path to YAML config file"
+    )
+    args = parser.parse_args()
+
+    cfg = load_yaml_config(args.config)
+    policy_config, service_config, prompt = get_configs(cfg)
+    asyncio.run(run_vllm(service_config, policy_config, prompt))
+
+
 if __name__ == "__main__":
-    asyncio.run(main())
+    sys.exit(main())
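
As an illustration of the prompt fallback in the new get_configs, here is a small hypothetical check (not part of the commit); it assumes apps.vllm.main is importable and uses the ServiceConfig fields shown in config.yaml.

# Hypothetical check of the get_configs prompt fallback (illustrative, not in the commit).
from apps.vllm.main import get_configs

cfg = {
    "policy_config": {"sampling_params": {"guided_decoding": True}},
    "service_config": {"procs_per_replica": 2, "num_replicas": 1, "with_gpus": True},
}

# With no top-level "prompt" key, the fallback depends on guided_decoding.
_, _, prompt = get_configs(cfg)
assert prompt == "What is 3+5?"

cfg["policy_config"]["sampling_params"]["guided_decoding"] = False
_, _, prompt = get_configs(cfg)
assert prompt == "Tell me a joke"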

src/forge/actors/policy.py

Lines changed: 23 additions & 5 deletions
@@ -61,7 +61,7 @@ class SamplingOverrides:
         guided_decoding: Whether to use guided decoding.
     """
 
-    num_samples: int
+    num_samples: int = 1
     guided_decoding: bool = False
     max_tokens: int = 512
 
@@ -79,19 +79,35 @@ class WorkerConfig:
         vllm_args: vLLM engine args.
     """
 
-    model: str
+    model: str = "meta-llama/Llama-3.1-8B-Instruct"
     tensor_parallel_size: int = 1
     pipeline_parallel_size: int = 1
     enforce_eager: bool = False
-    vllm_args: EngineArgs = None
+    vllm_args: EngineArgs = field(default_factory=EngineArgs)
+
+    @classmethod
+    def from_dict(cls, d: dict):
+        d = dict(d)  # copy
+        if "vllm_args" in d and isinstance(d["vllm_args"], dict):
+            d["vllm_args"] = EngineArgs(**d["vllm_args"])
+        return cls(**d)
 
 
 @dataclass
 class PolicyConfig:
-    worker_params: WorkerConfig
-    sampling_params: SamplingOverrides
+    worker_params: WorkerConfig = field(default_factory=WorkerConfig)
+    sampling_params: SamplingOverrides = field(default_factory=SamplingOverrides)
     available_devices: str = None
 
+    @classmethod
+    def from_dict(cls, d: dict):
+        d = dict(d)
+        if "worker_params" in d and isinstance(d["worker_params"], dict):
+            d["worker_params"] = WorkerConfig.from_dict(d["worker_params"])
+        if "sampling_params" in d and isinstance(d["sampling_params"], dict):
+            d["sampling_params"] = SamplingOverrides(**d["sampling_params"])
+        return cls(**d)
+
 
 @dataclass
 class Policy(PolicyInterface):
@@ -108,6 +124,8 @@ def __post_init__(self):
         self._policy_proc: ProcMesh | None = None
         self._worker_procs: ProcMesh | None = None
         self.weights_version: int = 0
+        if isinstance(self.config, dict):
+            self.config = PolicyConfig.from_dict(self.config)
 
     @classmethod
     async def launch(  # pyright: ignore[reportIncompatibleMethodOverride]
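
Taken together, the new defaults and from_dict methods let a nested plain dict be turned into a fully typed PolicyConfig. A minimal sketch, assuming forge and its vLLM dependency are installed:

# Sketch of the new from_dict path; omitted fields fall back to the new dataclass defaults.
from forge.actors.policy import PolicyConfig, WorkerConfig

cfg = {
    "worker_params": {
        "model": "meta-llama/Llama-3.1-8B-Instruct",
        "tensor_parallel_size": 2,
        "vllm_args": None,
    },
    "sampling_params": {"num_samples": 2, "guided_decoding": False},
}

policy_config = PolicyConfig.from_dict(cfg)
assert isinstance(policy_config.worker_params, WorkerConfig)
assert policy_config.worker_params.tensor_parallel_size == 2
assert policy_config.sampling_params.max_tokens == 512  # unchanged default
assert policy_config.available_devices is None

Because Policy.__post_init__ now converts a dict config with PolicyConfig.from_dict, the same nested mapping can also be passed directly as the config when constructing a Policy.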
