Skip to content

Commit d7f0c35

Browse files
authored
Add Structured Output Support to Policy (#46)
1 parent 6db9d25 commit d7f0c35

File tree

1 file changed

+35
-14
lines changed

1 file changed

+35
-14
lines changed

src/forge/actors/policy.py

Lines changed: 35 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import os
1010
import sys
1111
from dataclasses import dataclass
12+
from typing import Dict
1213

1314
import torch
1415
from monarch.actor import Actor, current_rank, endpoint, proc_mesh
@@ -18,7 +19,7 @@
1819
from vllm.executor.multiproc_worker_utils import set_multiprocessing_worker_envs
1920
from vllm.inputs import TextPrompt, TokensPrompt
2021
from vllm.lora.request import LoRARequest
21-
from vllm.sampling_params import RequestOutputKind, SamplingParams
22+
from vllm.sampling_params import GuidedDecodingParams, RequestOutputKind, SamplingParams
2223
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
2324
from vllm.usage.usage_lib import UsageContext
2425
from vllm.utils import get_distributed_init_method, get_loopback_ip, get_open_port
@@ -82,7 +83,8 @@ async def generate(self, prompt: str, priority: int = 0):
8283
self.request_id += 1 % sys.maxsize
8384
request_id = str(self.request_id) # implement from a counter
8485

85-
prompt = convert_input(prompt)
86+
# Wraps prompt into a dict
87+
prompt: Dict[str, str] = convert_input(prompt)
8688
if self.sampling_params is None:
8789
self.sampling_params = get_default_sampling_params(self.vllm_args)
8890

@@ -110,8 +112,14 @@ async def generate(self, prompt: str, priority: int = 0):
110112

111113
if self.sampling_params.n == 1:
112114
self.output_processor.add_request(request, prompt_str, None, 0)
113-
request = Request.from_engine_core_request(request)
114-
# TODO: mm_hash and structured_output
115+
116+
if request.mm_hashes is not None:
117+
# TODO: Support mm_hash
118+
pass
119+
request: Request = Request.from_engine_core_request(request)
120+
if request.use_structured_output:
121+
self.scheduler.structured_output_manager.grammar_init(request)
122+
115123
request_fut = asyncio.Future()
116124
self.requests[request_id] = request_fut
117125
self.scheduler.add_request(request)
@@ -284,17 +292,14 @@ def setup_worker(self):
284292
return worker
285293

286294

287-
def convert_input(prompt=None, prompt_token_ids=None):
288-
assert prompt is None or prompt_token_ids is None
295+
def convert_input(prompt=None, prompt_token_ids=None) -> Dict:
296+
assert (prompt is None) ^ (prompt_token_ids is None)
289297
if prompt is not None:
290298
return {"prompt": prompt}
291-
elif prompt_token_ids is not None:
292-
return {"prompt_token_ids": prompt_token_ids}
293-
else:
294-
raise ValueError("Either prompt or prompt_token_ids must be provided.")
299+
return {"prompt_token_ids": prompt_token_ids}
295300

296301

297-
def get_default_sampling_params(vllm_config, overrides=None):
302+
def get_default_sampling_params(vllm_config, overrides=None) -> SamplingParams:
298303
default_params = vllm_config.model_config.get_diff_sampling_param()
299304
default_params["max_tokens"] = 512
300305
if overrides is not None:
@@ -308,7 +313,7 @@ def get_default_sampling_params(vllm_config, overrides=None):
308313
return params
309314

310315

311-
async def _test(config):
316+
async def _test(config, guided_decoding=False):
312317
# TODO: Create a proper test
313318
router_mesh = await proc_mesh(gpus=1)
314319
policy_mesh = await proc_mesh(
@@ -320,7 +325,22 @@ async def _test(config):
320325
)
321326

322327
policy_actor = await policy_mesh.spawn("policy", Policy, **config)
323-
router = await router_mesh.spawn("policy_router", PolicyRouter, policy=policy_actor)
328+
329+
sampling_params = None
330+
if guided_decoding:
331+
# Add config for structured output
332+
vllm_args = await policy_actor.get_vllm_args.choose()
333+
guided_decoding_params = GuidedDecodingParams(choice=["Positive", "Negative"])
334+
335+
sampling_params = get_default_sampling_params(vllm_args)
336+
sampling_params.guided_decoding = guided_decoding_params
337+
338+
router = await router_mesh.spawn(
339+
"policy_router",
340+
PolicyRouter,
341+
policy=policy_actor,
342+
sampling_params=sampling_params,
343+
)
324344

325345
await policy_actor.setup.call()
326346
await router.setup.call()
@@ -329,7 +349,7 @@ async def _test(config):
329349
router.run.call()
330350
print("Model running")
331351

332-
prompt = "Tell me a joke"
352+
prompt = "What is 3+5?" if guided_decoding else "Tell me a joke"
333353
response = await router.generate.call_one(prompt)
334354
print(f"User: {prompt}\nAssistant: {response.outputs[0].text}")
335355

@@ -345,3 +365,4 @@ async def _test(config):
345365
"resources": 2,
346366
}
347367
asyncio.run(_test(config))
368+
# asyncio.run(_test(config, guided_decoding=True))

0 commit comments

Comments
 (0)