diff --git a/apps/vllm/main.py b/apps/vllm/main.py
new file mode 100644
index 000000000..111a8d5e3
--- /dev/null
+++ b/apps/vllm/main.py
@@ -0,0 +1,108 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""To run:
+
+python -m apps.vllm.main --guided-decoding --num-samples 3
+
+"""
+
+import argparse
+import asyncio
+from argparse import Namespace
+from typing import List
+
+from forge.actors.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig
+from forge.controller.service import ServiceConfig
+from forge.controller.spawn import spawn_service
+from vllm.outputs import CompletionOutput
+
+
+async def main():
+    """Main application for running vLLM policy inference."""
+    args = parse_args()
+
+    # Create configuration objects
+    policy_config, service_config = get_configs(args)
+
+    # Resolve the prompt
+    if args.prompt is None:
+        prompt = "What is 3+5?" if args.guided_decoding else "Tell me a joke"
+    else:
+        prompt = args.prompt
+
+    # Run the policy
+    await run_vllm(service_config, policy_config, prompt)
+
+
+def parse_args() -> Namespace:
+    parser = argparse.ArgumentParser(description="vLLM policy inference application")
+    parser.add_argument(
+        "--model",
+        type=str,
+        default="meta-llama/Llama-3.1-8B-Instruct",
+        help="Model to use",
+    )
+    parser.add_argument(
+        "--num-samples", type=int, default=2, help="Number of samples to generate"
+    )
+    parser.add_argument(
+        "--guided-decoding", action="store_true", help="Enable guided decoding"
+    )
+    parser.add_argument(
+        "--prompt", type=str, default=None, help="Custom prompt to use for generation"
+    )
+    return parser.parse_args()
+
+
+def get_configs(args: Namespace) -> tuple[PolicyConfig, ServiceConfig]:
+    worker_params = WorkerConfig(
+        model=args.model,
+        tensor_parallel_size=2,
+        pipeline_parallel_size=1,
+        enforce_eager=True,
+        vllm_args=None,
+    )
+
+    sampling_params = SamplingOverrides(
+        num_samples=args.num_samples,
+        guided_decoding=args.guided_decoding,
+    )
+
+    policy_config = PolicyConfig(
+        num_workers=2, worker_params=worker_params, sampling_params=sampling_params
+    )
+    service_config = ServiceConfig(procs_per_replica=1, num_replicas=1)
+
+    return policy_config, service_config
+
+
+async def run_vllm(service_config: ServiceConfig, config: PolicyConfig, prompt: str):
+    print("Spawning service...")
+    policy = await spawn_service(service_config, Policy, config=config)
+    session_id = await policy.start_session()
+
+    print("Starting background processing...")
+    processing_task = asyncio.create_task(policy.run_processing.call())
+
+    print("Requesting generation...")
+    responses: List[CompletionOutput] = await policy.generate.choose(prompt=prompt)
+
+    print("\nGeneration Results:")
+    print("=" * 80)
+    for batch, response in enumerate(responses):
+        print(f"Sample {batch + 1}:")
+        print(f"User: {prompt}")
+        print(f"Assistant: {response.text}")
+        print("-" * 80)
+
+    print("\nShutting down...")
+    await policy.shutdown.call()
+    await policy.terminate_session(session_id)
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py
index 20f0744bf..08ea769f8 100644
--- a/src/forge/actors/policy.py
+++ b/src/forge/actors/policy.py
@@ -470,49 +470,3 @@ def get_default_sampling_params(vllm_config, overrides=None) -> SamplingParams:
     # We only care about the final output
     params.output_kind = RequestOutputKind.FINAL_ONLY
     return params
-
-
-# TODO: Create proper test
-async def _test(config: DictConfig):
-    prompt = (
-        "What is 3+5?" if config.sampling_params.guided_decoding else "Tell me a joke"
-    )
-    service_config = ServiceConfig(procs_per_replica=1, num_replicas=1)
-
-    print("Spawning service")
-    policy = await spawn_service(service_config, Policy, config=config)
-    session_id = await policy.start_session()
-
-    print("Kick off background processing")
-    asyncio.create_task(policy.run_processing.call())
-
-    print("Request Generation")
-    responses: List[CompletionOutput] = await policy.generate.choose(prompt=prompt)
-
-    print("Terminating session")
-    await policy.shutdown.call()
-    await policy.terminate_session(session_id)
-
-    for batch, response in enumerate(responses):
-        print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
-        print(f"Batch {batch}:")
-        print(f"User: {prompt}\nAssistant: {response.text}")
-        print("\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
-
-
-if __name__ == "__main__":
-    config = PolicyConfig(
-        num_workers=2,
-        worker_params=WorkerConfig(
-            model="meta-llama/Llama-3.1-8B-Instruct",
-            tensor_parallel_size=2,
-            pipeline_parallel_size=1,
-            enforce_eager=True,
-            vllm_args=None,
-        ),
-        sampling_params=SamplingOverrides(
-            num_samples=2,
-            guided_decoding=True,
-        ),
-    )
-    asyncio.run(_test(config))