diff --git a/apps/vllm/main.py b/apps/vllm/main.py index 9da85e631..6f98a512e 100644 --- a/apps/vllm/main.py +++ b/apps/vllm/main.py @@ -13,11 +13,10 @@ import argparse import asyncio from argparse import Namespace -from typing import List from forge.actors.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig from forge.controller.service import ServiceConfig, shutdown_service, spawn_service -from vllm.outputs import CompletionOutput +from vllm.outputs import RequestOutput async def main(): @@ -89,11 +88,11 @@ async def run_vllm(service_config: ServiceConfig, config: PolicyConfig, prompt: async with policy.session(): print("Requesting generation...") - responses: List[CompletionOutput] = await policy.generate.choose(prompt=prompt) + response_output: RequestOutput = await policy.generate.choose(prompt=prompt) print("\nGeneration Results:") print("=" * 80) - for batch, response in enumerate(responses): + for batch, response in enumerate(response_output.outputs): print(f"Sample {batch + 1}:") print(f"User: {prompt}") print(f"Assistant: {response.text}")