diff --git a/apps/vllm/llama3_8b.yaml b/apps/vllm/llama3_8b.yaml
index 2134d7f55..7907dad88 100644
--- a/apps/vllm/llama3_8b.yaml
+++ b/apps/vllm/llama3_8b.yaml
@@ -11,8 +11,8 @@ policy:
 
 services:
   policy:
-    procs: 2
-    num_replicas: 1
+    procs: ${policy.engine_config.tensor_parallel_size}
+    num_replicas: 4
     with_gpus: true
diff --git a/apps/vllm/main.py b/apps/vllm/main.py
index 8789cbe29..6ba1bbbaf 100644
--- a/apps/vllm/main.py
+++ b/apps/vllm/main.py
@@ -33,25 +33,32 @@ async def run(cfg: DictConfig):
 
     print("Spawning service...")
     policy = await Policy.options(**cfg.services.policy).as_service(**cfg.policy)
 
-    try:
-        async with policy.session():
-            print("Requesting generation...")
-            response_output: list[Completion] = await policy.generate.choose(
-                prompt=prompt
-            )
-
-            print("\nGeneration Results:")
-            print("=" * 80)
-            for batch, response in enumerate(response_output):
-                print(f"Sample {batch + 1}:")
-                print(f"User: {prompt}")
-                print(f"Assistant: {response.text}")
-                print("-" * 80)
-
-    finally:
-        print("\nShutting down...")
-        await policy.shutdown()
-        await shutdown()
+    import time
+
+    print("Requesting generation...")
+    n = 100
+    start = time.time()
+    response_outputs: list[list[Completion]] = await asyncio.gather(
+        *[policy.generate.choose(prompt=prompt) for _ in range(n)]
+    )
+    end = time.time()
+
+    print(f"Generation of {n} requests completed in {end - start:.2f} seconds.")
+    print(
+        f"Generation with procs {cfg.services.policy.procs}, replicas {cfg.services.policy.num_replicas}"
+    )
+
+    print(f"\nGeneration Results (last one of {n} requests):")
+    print("=" * 80)
+    for batch, response in enumerate(response_outputs[-1]):
+        print(f"Sample {batch + 1}:")
+        print(f"User: {prompt}")
+        print(f"Assistant: {response.text}")
+        print("-" * 80)
+
+    print("\nShutting down...")
+    await policy.shutdown()
+    await shutdown()
 
 @parse