Commit 9278d75

remove policy config
1 parent 64687d9 commit 9278d75

File tree

4 files changed: +172 additions, -55 deletions

apps/vllm/config.yaml

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-policy_config:
+policy:
   worker_params:
     model: "meta-llama/Llama-3.1-8B-Instruct"
     tensor_parallel_size: 2
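
Note: the rename from policy_config to policy is load-bearing. As the main.py diff below shows, this section is now splatted directly into the Policy actor's dataclass fields instead of being routed through PolicyConfig.from_dict. A minimal sketch of how the renamed section can be consumed, assuming OmegaConf and this config path:

# Sketch only: load the renamed `policy` section and splat it into Policy,
# mirroring the kwargs that spawn_service receives in the updated main.py.
from omegaconf import OmegaConf

from forge.actors.policy import Policy

cfg = OmegaConf.load("apps/vllm/config.yaml")

# Convert to plain dicts so Policy.__post_init__ can coerce the nested
# worker_params dict into a WorkerConfig dataclass.
policy_kwargs = OmegaConf.to_container(cfg.policy, resolve=True)
policy = Policy(**policy_kwargs)
print(policy.worker_params.tensor_parallel_size)  # 2, per the YAML above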

apps/vllm/main.py

Lines changed: 15 additions & 31 deletions
@@ -11,38 +11,31 @@
 
 import asyncio
 import sys
-from typing import Any
 
-import yaml
-
-from forge.actors.policy import Policy, PolicyConfig
+from forge.actors.policy import Policy
+from forge.cli.config import parse
 from forge.controller.service import ServiceConfig, shutdown_service, spawn_service
+from omegaconf import DictConfig
+from vllm.outputs import RequestOutput
 
 
-def load_yaml_config(path: str) -> dict:
-    with open(path, "r") as f:
-        return yaml.safe_load(f)
-
+async def run(cfg: DictConfig):
 
-def get_configs(cfg: dict) -> tuple[PolicyConfig, ServiceConfig, str]:
-    # Instantiate PolicyConfig and ServiceConfig from nested dicts
-    policy_config = PolicyConfig.from_dict(cfg["policy_config"])
-    service_config = ServiceConfig(**cfg["service_config"])
     if "prompt" in cfg and cfg["prompt"] is not None:
         prompt = cfg["prompt"]
     else:
-        gd = policy_config.sampling_params.guided_decoding
+        gd = cfg.policy.get("sampling_params", {}).get("guided_decoding", False)
         prompt = "What is 3+5?" if gd else "Tell me a joke"
-    return policy_config, service_config, prompt
 
-
-async def run_vllm(service_config: ServiceConfig, config: PolicyConfig, prompt: str):
     print("Spawning service...")
-    policy = await spawn_service(service_config, Policy, config=config)
+
+    policy = await spawn_service(
+        ServiceConfig(**cfg.service_config), Policy, **cfg.policy
+    )
 
     async with policy.session():
         print("Requesting generation...")
-        response_output = await policy.generate.choose(prompt=prompt)
+        response_output: RequestOutput = await policy.generate.choose(prompt=prompt)
 
         print("\nGeneration Results:")
         print("=" * 80)
@@ -57,19 +50,10 @@ async def run_vllm(service_config: ServiceConfig, config: PolicyConfig, prompt:
         await shutdown_service(policy)
 
 
-def main():
-    import argparse
-
-    parser = argparse.ArgumentParser(description="vLLM Policy Inference Application")
-    parser.add_argument(
-        "--config", type=str, required=True, help="Path to YAML config file"
-    )
-    args = parser.parse_args()
-
-    cfg = load_yaml_config(args.config)
-    policy_config, service_config, prompt = get_configs(cfg)
-    asyncio.run(run_vllm(service_config, policy_config, prompt))
+@parse
+def recipe_main(cfg: DictConfig) -> None:
+    asyncio.run(run(cfg))
 
 
 if __name__ == "__main__":
-    sys.exit(main())
+    sys.exit(recipe_main())
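
The @parse decorator from forge.cli.config replaces the hand-rolled argparse/yaml plumbing deleted above. Its implementation is not shown in this commit, so the following is only a hypothetical stand-in illustrating the general shape of such a decorator (read --config, build a DictConfig, call the wrapped entry point); the real one may differ, e.g. by supporting CLI overrides:

# Hypothetical stand-in for forge.cli.config.parse; not the real implementation.
import argparse
import functools
from typing import Callable

from omegaconf import DictConfig, OmegaConf


def parse(recipe: Callable[[DictConfig], None]) -> Callable[[], None]:
    @functools.wraps(recipe)
    def wrapper() -> None:
        parser = argparse.ArgumentParser(description=recipe.__doc__)
        parser.add_argument(
            "--config", type=str, required=True, help="Path to YAML config file"
        )
        args = parser.parse_args()
        cfg = OmegaConf.load(args.config)  # yields a DictConfig for YAML mappings
        recipe(cfg)  # hand the parsed config to the wrapped entry point
    return wrapper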

src/forge/actors/policy.py

Lines changed: 25 additions & 23 deletions
@@ -20,6 +20,7 @@
 from forge.interfaces import Policy as PolicyInterface
 from forge.types import ProcessConfig
 from monarch.actor import current_rank, endpoint, ProcMesh
+from omegaconf import DictConfig
 from torchstore import MultiProcessStore
 from torchstore._state_dict_utils import DELIM
 
@@ -90,28 +91,16 @@ def from_dict(cls, d: dict):
         d = dict(d)  # copy
         if "vllm_args" in d and isinstance(d["vllm_args"], dict):
             d["vllm_args"] = EngineArgs(**d["vllm_args"])
-        return cls(**d)
-
-
-@dataclass
-class PolicyConfig:
-    worker_params: WorkerConfig = field(default_factory=WorkerConfig)
-    sampling_params: SamplingOverrides = field(default_factory=SamplingOverrides)
-    available_devices: str = None
-
-    @classmethod
-    def from_dict(cls, d: dict):
-        d = dict(d)
-        if "worker_params" in d and isinstance(d["worker_params"], dict):
-            d["worker_params"] = WorkerConfig.from_dict(d["worker_params"])
-        if "sampling_params" in d and isinstance(d["sampling_params"], dict):
-            d["sampling_params"] = SamplingOverrides(**d["sampling_params"])
+        else:
+            d["vllm_args"] = None
         return cls(**d)
 
 
 @dataclass
 class Policy(PolicyInterface):
-    config: PolicyConfig
+    worker_params: WorkerConfig = field(default_factory=WorkerConfig)
+    sampling_overrides: SamplingOverrides = field(default_factory=SamplingOverrides)
+    available_devices: str | None = None
     # Gets set up by setup
     sampling_params: SamplingParams | None = None
     lora_request: LoRARequest | None = None
@@ -124,15 +113,19 @@ def __post_init__(self):
         self._policy_proc: ProcMesh | None = None
         self._worker_procs: ProcMesh | None = None
         self.weights_version: int = 0
-        if isinstance(self.config, dict):
-            self.config = PolicyConfig.from_dict(self.config)
+        if isinstance(self.worker_params, dict):
+            self.worker_params = WorkerConfig.from_dict(self.worker_params)
+        if isinstance(self.sampling_overrides, dict):
+            self.sampling_overrides = SamplingOverrides(**self.sampling_overrides)
 
     @classmethod
     async def launch(  # pyright: ignore[reportIncompatibleMethodOverride]
         cls: type["Policy"],
         *,
         process_config: ProcessConfig,
-        config: PolicyConfig,
+        worker_params: WorkerConfig | dict = WorkerConfig(),
+        sampling_overrides: SamplingOverrides | dict = SamplingOverrides(),
+        available_devices: str | None = None,
         store: MultiProcessStore | None = None,
         **kwargs,
     ) -> "Policy":
@@ -146,16 +139,25 @@ async def launch(  # pyright: ignore[reportIncompatibleMethodOverride]
         policy_proc_config.with_gpus = False
 
         policy_proc = await get_proc_mesh(process_config=policy_proc_config)
+
+        if isinstance(worker_params, (dict, DictConfig)):
+            worker_params = WorkerConfig.from_dict(worker_params)
+
+        if isinstance(sampling_overrides, (dict, DictConfig)):
+            sampling_overrides = SamplingOverrides(**sampling_overrides)
+
         workers = await worker_procs.spawn(
-            "vllm_worker", PolicyWorker, **asdict(config.worker_params)
+            "vllm_worker", PolicyWorker, **asdict(worker_params)
         )
 
         # TODO - expand support so name can stick within kwargs
         actor_name = kwargs.pop("name", cls.__name__)
         policy = await policy_proc.spawn(
             actor_name,
             cls,
-            config=config,
+            worker_params=worker_params,
+            sampling_overrides=sampling_overrides,
+            available_devices=available_devices,
             policy_worker=workers,
             store=store,
         )
@@ -192,7 +194,7 @@ async def setup(self):
         self.vllm_args = await self.policy_worker.get_vllm_args.choose()
 
         # Setup sampling params
-        sampling_overrides = self.config.sampling_params
+        sampling_overrides = self.sampling_overrides
         overrides = {
             "n": sampling_overrides.num_samples,
             "guided_decoding": (
Lines changed: 131 additions & 0 deletions
@@ -0,0 +1,131 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import inspect
+import tempfile
+import unittest
+from dataclasses import asdict
+
+import yaml
+
+from forge.actors.policy import Policy, SamplingOverrides, WorkerConfig
+from vllm.engine.arg_utils import EngineArgs
+
+
+class TestPolicyConfig(unittest.TestCase):
+    """Test suite for Policy configuration handling after PolicyConfig removal."""
+
+    def test_policy_default_initialization(self):
+        """Test that Policy can be initialized with default values."""
+        policy = Policy()
+
+        # Check that default factories work
+        self.assertIsInstance(policy.worker_params, WorkerConfig)
+        self.assertIsInstance(policy.sampling_overrides, SamplingOverrides)
+        self.assertIsNone(policy.available_devices)
+
+        # Check default values
+        self.assertEqual(policy.worker_params.model, "meta-llama/Llama-3.1-8B-Instruct")
+        self.assertEqual(policy.worker_params.tensor_parallel_size, 1)
+        self.assertEqual(policy.worker_params.pipeline_parallel_size, 1)
+        self.assertFalse(policy.worker_params.enforce_eager)
+
+        self.assertEqual(policy.sampling_overrides.num_samples, 1)
+        self.assertFalse(policy.sampling_overrides.guided_decoding)
+        self.assertEqual(policy.sampling_overrides.max_tokens, 512)
+
+    def test_policy_with_dict_configs(self):
+        """Test Policy initialization with dictionary configs."""
+        worker_dict = {
+            "model": "test-model-6789",
+            "tensor_parallel_size": 7777,
+            "pipeline_parallel_size": 8888,
+            "enforce_eager": True,
+            "vllm_args": {"max_model_len": 9999, "gpu_memory_utilization": 0.1234},
+        }
+
+        sampling_dict = {
+            "num_samples": 1357,
+            "guided_decoding": True,
+            "max_tokens": 2468,
+        }
+
+        policy = Policy(
+            worker_params=worker_dict,
+            sampling_overrides=sampling_dict,
+            available_devices="test-gpu-device-abcd",
+        )
+
+        # Check that dictionaries were converted to proper objects
+        self.assertIsInstance(policy.worker_params, WorkerConfig)
+        self.assertIsInstance(policy.sampling_overrides, SamplingOverrides)
+
+        self.assertEqual(policy.worker_params.model, "test-model-6789")
+        self.assertEqual(policy.worker_params.tensor_parallel_size, 7777)
+        self.assertEqual(policy.worker_params.pipeline_parallel_size, 8888)
+        self.assertTrue(policy.worker_params.enforce_eager)
+
+        self.assertEqual(policy.sampling_overrides.num_samples, 1357)
+        self.assertTrue(policy.sampling_overrides.guided_decoding)
+        self.assertEqual(policy.sampling_overrides.max_tokens, 2468)
+
+    def test_policy_yaml_config_loading(self):
+        """Test loading Policy configuration from YAML file."""
+        yaml_content = """
+worker_params:
+  model: "yaml-test-model-9876"
+  tensor_parallel_size: 1234
+  pipeline_parallel_size: 5678
+  enforce_eager: true
+  vllm_args:
+    max_model_len: 9876
+    gpu_memory_utilization: 0.1357
+
+sampling_overrides:
+  num_samples: 2468
+  guided_decoding: true
+  max_tokens: 1357
+
+available_devices: "yaml-test-device-xyz"
+"""
+
+        with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
+            f.write(yaml_content)
+            f.flush()
+
+            # Load YAML and create Policy
+            with open(f.name, "r") as yaml_file:
+                config = yaml.safe_load(yaml_file)
+
+            policy = Policy(**config)
+
+        self.assertEqual(policy.worker_params.model, "yaml-test-model-9876")
+        self.assertEqual(policy.worker_params.tensor_parallel_size, 1234)
+        self.assertEqual(policy.worker_params.pipeline_parallel_size, 5678)
+        self.assertTrue(policy.worker_params.enforce_eager)
+
+        self.assertEqual(policy.sampling_overrides.num_samples, 2468)
+        self.assertTrue(policy.sampling_overrides.guided_decoding)
+        self.assertEqual(policy.sampling_overrides.max_tokens, 1357)
+
+        self.assertEqual(policy.available_devices, "yaml-test-device-xyz")
+
+    def test_invalid_worker_config_from_dict(self):
+        """Test that WorkerConfig.from_dict handles invalid vllm_args gracefully."""
+        config_dict = {
+            "model": "meta-llama/Llama-3.1-8B-Instruct",
+            "vllm_args": "invalid_string_instead_of_dict",  # not a dict
+        }
+
+        worker_config = WorkerConfig.from_dict(config_dict)
+
+        # A non-dict vllm_args is replaced with None by from_dict
+        self.assertEqual(worker_config.model, "meta-llama/Llama-3.1-8B-Instruct")
+        self.assertEqual(worker_config.vllm_args, None)
+
+
+if __name__ == "__main__":
+    unittest.main()
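
Since the module ends in a unittest.main() guard, the new tests can be run either directly as a script or through the standard runner (for example, python -m unittest discover), whichever matches this repo's test setup.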
