@@ -6,81 +6,34 @@
 
 """To run:
 export HF_HUB_DISABLE_XET=1
-python -m apps.vllm.main --guided-decoding --num-samples 3
-
+python -m apps.vllm.main --config apps/vllm/config.yaml
 """
 
-import argparse
 import asyncio
-from argparse import Namespace
-from typing import List
+import sys
+from typing import Any
+
+import yaml
 
-from forge.actors.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig
+from forge.actors.policy import Policy, PolicyConfig
 from forge.controller.service import ServiceConfig, shutdown_service, spawn_service
-from vllm.outputs import CompletionOutput
 
 
-async def main():
-    """Main application for running vLLM policy inference."""
-    args = parse_args()
+def load_yaml_config(path: str) -> dict:
+    with open(path, "r") as f:
+        return yaml.safe_load(f)
 
-    # Create configuration objects
-    policy_config, service_config = get_configs(args)
 
-    # Resolve the Prompts
-    if args.prompt is None:
-        prompt = "What is 3+5?" if args.guided_decoding else "Tell me a joke"
+def get_configs(cfg: dict) -> tuple[PolicyConfig, ServiceConfig, str]:
+    # Instantiate PolicyConfig and ServiceConfig from nested dicts
+    policy_config = PolicyConfig.from_dict(cfg["policy_config"])
+    service_config = ServiceConfig(**cfg["service_config"])
+    if "prompt" in cfg and cfg["prompt"] is not None:
+        prompt = cfg["prompt"]
     else:
-        prompt = args.prompt
-
-    # Run the policy
-    await run_vllm(service_config, policy_config, prompt)
-
-
-def parse_args() -> Namespace:
-    parser = argparse.ArgumentParser(description="VLLM Policy Inference Application")
-    parser.add_argument(
-        "--model",
-        type=str,
-        default="meta-llama/Llama-3.1-8B-Instruct",
-        help="Model to use",
-    )
-    parser.add_argument(
-        "--num-samples", type=int, default=2, help="Number of samples to generate"
-    )
-    parser.add_argument(
-        "--guided-decoding", action="store_true", help="Enable guided decoding"
-    )
-    parser.add_argument(
-        "--prompt", type=str, default=None, help="Custom prompt to use for generation"
-    )
-    return parser.parse_args()
-
-
-def get_configs(args: Namespace) -> (PolicyConfig, ServiceConfig):
-
-    worker_size = 2
-    worker_params = WorkerConfig(
-        model=args.model,
-        tensor_parallel_size=worker_size,
-        pipeline_parallel_size=1,
-        enforce_eager=True,
-        vllm_args=None,
-    )
-
-    sampling_params = SamplingOverrides(
-        num_samples=args.num_samples,
-        guided_decoding=args.guided_decoding,
-    )
-
-    policy_config = PolicyConfig(
-        worker_params=worker_params, sampling_params=sampling_params
-    )
-    service_config = ServiceConfig(
-        procs_per_replica=worker_size, num_replicas=1, with_gpus=True
-    )
-
-    return policy_config, service_config
+        gd = policy_config.sampling_params.guided_decoding
+        prompt = "What is 3+5?" if gd else "Tell me a joke"
+    return policy_config, service_config, prompt
 
 
 async def run_vllm(service_config: ServiceConfig, config: PolicyConfig, prompt: str):
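
The PR does not include apps/vllm/config.yaml itself, so the schema get_configs expects has to be inferred. Below is a minimal sketch of such a file, assuming PolicyConfig.from_dict accepts the same nested worker_params/sampling_params layout the removed argparse path built by hand; every field name is carried over from that removed code, not confirmed against the real file:

    import yaml

    # Hypothetical apps/vllm/config.yaml contents, mirroring the removed
    # WorkerConfig / SamplingOverrides / ServiceConfig arguments above.
    EXAMPLE_CONFIG = """\
    policy_config:
      worker_params:
        model: meta-llama/Llama-3.1-8B-Instruct
        tensor_parallel_size: 2
        pipeline_parallel_size: 1
        enforce_eager: true
        vllm_args: null
      sampling_params:
        num_samples: 2
        guided_decoding: false
    service_config:
      procs_per_replica: 2
      num_replicas: 1
      with_gpus: true
    prompt: null  # optional; get_configs falls back to a default prompt
    """

    cfg = yaml.safe_load(EXAMPLE_CONFIG)
    assert set(cfg) == {"policy_config", "service_config", "prompt"}
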
@@ -89,11 +42,11 @@ async def run_vllm(service_config: ServiceConfig, config: PolicyConfig, prompt:
 
     async with policy.session():
         print("Requesting generation...")
-        responses: List[CompletionOutput] = await policy.generate.choose(prompt=prompt)
+        response_output = await policy.generate.choose(prompt=prompt)
 
         print("\nGeneration Results:")
         print("=" * 80)
-        for batch, response in enumerate(responses):
+        for batch, response in enumerate(response_output.outputs):
             print(f"Sample {batch + 1}:")
             print(f"User: {prompt}")
             print(f"Assistant: {response.text}")
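
For context on the loop change: the old annotation treated the response as a bare List[CompletionOutput], while the new code assumes generate.choose resolves to a single vLLM-style RequestOutput whose .outputs list holds one completion per sample. A stand-in sketch of that assumed shape, with hypothetical types and values rather than Forge's actual classes:

    from dataclasses import dataclass, field

    @dataclass
    class Completion:  # stand-in for vllm.outputs.CompletionOutput
        text: str

    @dataclass
    class RequestOutput:  # stand-in for what generate.choose() returns
        outputs: list[Completion] = field(default_factory=list)

    # Two samples, as if num_samples were 2
    response_output = RequestOutput(outputs=[Completion("8"), Completion("3 + 5 = 8")])
    for batch, response in enumerate(response_output.outputs):
        print(f"Sample {batch + 1}: {response.text}")
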
@@ -104,5 +57,19 @@ async def run_vllm(service_config: ServiceConfig, config: PolicyConfig, prompt:
     await shutdown_service(policy)
 
 
+def main():
+    import argparse
+
+    parser = argparse.ArgumentParser(description="vLLM Policy Inference Application")
+    parser.add_argument(
+        "--config", type=str, required=True, help="Path to YAML config file"
+    )
+    args = parser.parse_args()
+
+    cfg = load_yaml_config(args.config)
+    policy_config, service_config, prompt = get_configs(cfg)
+    asyncio.run(run_vllm(service_config, policy_config, prompt))
+
+
 if __name__ == "__main__":
-    asyncio.run(main())
+    sys.exit(main())