108 changes: 108 additions & 0 deletions apps/vllm/main.py
@@ -0,0 +1,108 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""To run:

python -m apps.vllm.main --guided-decoding --num-samples 3
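
or, with a custom prompt (using the --prompt flag defined below):

python -m apps.vllm.main --prompt "Tell me about vLLM"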

"""

import argparse
import asyncio
from argparse import Namespace
from typing import List, Tuple

from forge.actors.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig
from forge.controller.service import ServiceConfig
from forge.controller.spawn import spawn_service
from vllm.outputs import CompletionOutput


async def main():
"""Main application for running vLLM policy inference."""
args = parse_args()

# Create configuration objects
policy_config, service_config = get_configs(args)

# Resolve the Prompts
if args.prompt is None:
prompt = "What is 3+5?" if args.guided_decoding else "Tell me a joke"
else:
prompt = args.prompt

# Run the policy
await run_vllm(service_config, policy_config, prompt)


def parse_args() -> Namespace:
    parser = argparse.ArgumentParser(description="vLLM Policy Inference Application")
parser.add_argument(
"--model",
type=str,
default="meta-llama/Llama-3.1-8B-Instruct",
help="Model to use",
)
parser.add_argument(
"--num-samples", type=int, default=2, help="Number of samples to generate"
)
parser.add_argument(
"--guided-decoding", action="store_true", help="Enable guided decoding"
)
parser.add_argument(
"--prompt", type=str, default=None, help="Custom prompt to use for generation"
)
return parser.parse_args()


def get_configs(args: Namespace) -> Tuple[PolicyConfig, ServiceConfig]:
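    # Sizing assumptions for this example: tensor_parallel_size=2 shards the
    # model across two GPUs per worker (vLLM tensor parallelism), and
    # enforce_eager=True skips CUDA graph capture for faster startup.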
worker_params = WorkerConfig(
model=args.model,
tensor_parallel_size=2,
pipeline_parallel_size=1,
enforce_eager=True,
vllm_args=None,
)

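    # num_samples is the number of completions returned per prompt;
    # guided_decoding is expected to constrain outputs via vLLM's
    # guided-decoding support.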
sampling_params = SamplingOverrides(
num_samples=args.num_samples,
guided_decoding=args.guided_decoding,
)

policy_config = PolicyConfig(
num_workers=2, worker_params=worker_params, sampling_params=sampling_params
)
service_config = ServiceConfig(procs_per_replica=1, num_replicas=1)

return policy_config, service_config


async def run_vllm(service_config: ServiceConfig, config: PolicyConfig, prompt: str):
print("Spawning service...")
policy = await spawn_service(service_config, Policy, config=config)
session_id = await policy.start_session()

print("Starting background processing...")
processing_task = asyncio.create_task(policy.run_processing.call())

print("Requesting generation...")
responses: List[CompletionOutput] = await policy.generate.choose(prompt=prompt)

print("\nGeneration Results:")
print("=" * 80)
    for i, response in enumerate(responses):
        print(f"Sample {i + 1}:")
print(f"User: {prompt}")
print(f"Assistant: {response.text}")
print("-" * 80)

print("\nShutting down...")
await policy.shutdown.call()
await policy.terminate_session(session_id)


if __name__ == "__main__":
asyncio.run(main())
46 changes: 0 additions & 46 deletions src/forge/actors/policy.py
@@ -470,49 +470,3 @@ def get_default_sampling_params(vllm_config, overrides=None) -> SamplingParams:
# We only care about the final output
params.output_kind = RequestOutputKind.FINAL_ONLY
return params


# TODO: Create proper test
async def _test(config: DictConfig):
prompt = (
"What is 3+5?" if config.sampling_params.guided_decoding else "Tell me a joke"
)
service_config = ServiceConfig(procs_per_replica=1, num_replicas=1)

print("Spawning service")
policy = await spawn_service(service_config, Policy, config=config)
session_id = await policy.start_session()

print("Kick off background processing")
asyncio.create_task(policy.run_processing.call())

print("Request Generation")
responses: List[CompletionOutput] = await policy.generate.choose(prompt=prompt)

print("Terminating session")
await policy.shutdown.call()
await policy.terminate_session(session_id)

for batch, response in enumerate(responses):
print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
print(f"Batch {batch}:")
print(f"User: {prompt}\nAssistant: {response.text}")
print("\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")


if __name__ == "__main__":
config = PolicyConfig(
num_workers=2,
worker_params=WorkerConfig(
model="meta-llama/Llama-3.1-8B-Instruct",
tensor_parallel_size=2,
pipeline_parallel_size=1,
enforce_eager=True,
vllm_args=None,
),
sampling_params=SamplingOverrides(
num_samples=2,
guided_decoding=True,
),
)
asyncio.run(_test(config))