 import os
 import sys
 from dataclasses import dataclass
+from typing import Dict
 
 import torch
 from monarch.actor import Actor, current_rank, endpoint, proc_mesh
 from vllm.executor.multiproc_worker_utils import set_multiprocessing_worker_envs
 from vllm.inputs import TextPrompt, TokensPrompt
 from vllm.lora.request import LoRARequest
-from vllm.sampling_params import RequestOutputKind, SamplingParams
+from vllm.sampling_params import GuidedDecodingParams, RequestOutputKind, SamplingParams
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import get_distributed_init_method, get_loopback_ip, get_open_port
@@ -82,7 +83,8 @@ async def generate(self, prompt: str, priority: int = 0):
         self.request_id += 1 % sys.maxsize
         request_id = str(self.request_id)  # implement from a counter
 
-        prompt = convert_input(prompt)
+        # Wraps prompt into a dict
+        prompt: Dict[str, str] = convert_input(prompt)
         if self.sampling_params is None:
             self.sampling_params = get_default_sampling_params(self.vllm_args)
 
@@ -110,8 +112,14 @@ async def generate(self, prompt: str, priority: int = 0):
 
         if self.sampling_params.n == 1:
             self.output_processor.add_request(request, prompt_str, None, 0)
-        request = Request.from_engine_core_request(request)
-        # TODO: mm_hash and sturcutured_output
+
+        if request.mm_hashes is not None:
+            # TODO: Support mm_hash
+            pass
+        request: Request = Request.from_engine_core_request(request)
+        if request.use_structured_output:
+            self.scheduler.structured_output_manager.grammar_init(request)
+
         request_fut = asyncio.Future()
         self.requests[request_id] = request_fut
         self.scheduler.add_request(request)
@@ -284,17 +292,14 @@ def setup_worker(self):
         return worker
 
 
-def convert_input(prompt=None, prompt_token_ids=None):
-    assert prompt is None or prompt_token_ids is None
+def convert_input(prompt=None, prompt_token_ids=None) -> Dict:
+    assert (prompt is None) ^ (prompt_token_ids is None)
     if prompt is not None:
         return {"prompt": prompt}
-    elif prompt_token_ids is not None:
-        return {"prompt_token_ids": prompt_token_ids}
-    else:
-        raise ValueError("Either prompt or prompt_token_ids must be provided.")
+    return {"prompt_token_ids": prompt_token_ids}
 
 
-def get_default_sampling_params(vllm_config, overrides=None):
+def get_default_sampling_params(vllm_config, overrides=None) -> SamplingParams:
     default_params = vllm_config.model_config.get_diff_sampling_param()
     default_params["max_tokens"] = 512
     if overrides is not None:
@@ -308,7 +313,7 @@ def get_default_sampling_params(vllm_config, overrides=None):
     return params
 
 
-async def _test(config):
+async def _test(config, guided_decoding=False):
     # TODO: Create proper test
     router_mesh = await proc_mesh(gpus=1)
     policy_mesh = await proc_mesh(
@@ -320,7 +325,22 @@ async def _test(config):
     )
 
     policy_actor = await policy_mesh.spawn("policy", Policy, **config)
-    router = await router_mesh.spawn("policy_router", PolicyRouter, policy=policy_actor)
+
+    sampling_params = None
+    if guided_decoding:
+        # Add config for structured output
+        vllm_args = await policy_actor.get_vllm_args.choose()
+        guided_decoding_params = GuidedDecodingParams(choice=["Positive", "Negative"])
+
+        sampling_params = get_default_sampling_params(vllm_args)
+        sampling_params.guided_decoding = guided_decoding_params
+
+    router = await router_mesh.spawn(
+        "policy_router",
+        PolicyRouter,
+        policy=policy_actor,
+        sampling_params=sampling_params,
+    )
 
     await policy_actor.setup.call()
     await router.setup.call()
@@ -329,7 +349,7 @@ async def _test(config):
     router.run.call()
     print("Model running")
 
-    prompt = "Tell me a joke"
+    prompt = "What is 3+5?" if guided_decoding else "Tell me a joke"
     response = await router.generate.call_one(prompt)
     print(f"User: {prompt}\nAssistant: {response.outputs[0].text}")
 
@@ -345,3 +365,4 @@ async def _test(config):
345365 "resources" : 2 ,
346366 }
347367 asyncio .run (_test (config ))
368+ # asyncio.run(_test(config, guided_decoding=True))
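
For reference, the guided-decoding path added above is meant to constrain generation to one of the listed choices. A minimal standalone sketch of the same idea against vLLM's offline API, assuming the GuidedDecodingParams/SamplingParams interface imported in this diff (the model name and prompt below are illustrative and not part of this PR):

from vllm import LLM
from vllm.sampling_params import GuidedDecodingParams, SamplingParams

# Constrain decoding to one of two labels, mirroring the choice list used in _test().
params = SamplingParams(
    max_tokens=512,
    guided_decoding=GuidedDecodingParams(choice=["Positive", "Negative"]),
)
llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct")  # illustrative model
outputs = llm.generate(["Is this review positive or negative? 'Great movie!'"], params)
print(outputs[0].outputs[0].text)  # expected to be exactly "Positive" or "Negative"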