diff --git a/vllm-speculative-decoding/config.yaml b/vllm-speculative-decoding/config.yaml
new file mode 100644
index 000000000..f9de8f2b9
--- /dev/null
+++ b/vllm-speculative-decoding/config.yaml
@@ -0,0 +1,23 @@
+# This base image is required for a developer build of vLLM
+base_image:
+  image: nvcr.io/nvidia/pytorch:23.11-py3
+  python_executable_path: /usr/bin/python3
+build_commands: []
+environment_variables:
+  HF_TOKEN: ""
+external_package_dirs: []
+model_metadata:
+  main_model: meta-llama/Meta-Llama-3-8B-Instruct
+  assistant_model: ibm-fms/llama3-8b-accelerator
+  tensor_parallel: 1
+  max_num_seqs: 16
+model_name: vLLM Speculative Decoding
+requirements:
+  - git+https://github.com/vllm-project/vllm@9def10664e8b54dcc5c6114f2895bc9e712bf182
+resources:
+  accelerator: A100
+  use_gpu: true
+system_packages:
+  - python3.10-venv
+runtime:
+  predict_concurrency: 128
diff --git a/vllm-speculative-decoding/model/__init__.py b/vllm-speculative-decoding/model/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/vllm-speculative-decoding/model/model.py b/vllm-speculative-decoding/model/model.py
new file mode 100644
index 000000000..e684a3af9
--- /dev/null
+++ b/vllm-speculative-decoding/model/model.py
@@ -0,0 +1,68 @@
+import logging
+import subprocess
+import uuid
+
+from vllm import SamplingParams
+from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.engine.async_llm_engine import AsyncLLMEngine
+
+logger = logging.getLogger(__name__)
+
+
+class Model:
+    def __init__(self, **kwargs):
+        self._config = kwargs["config"]
+        self.model = None
+        self.llm_engine = None
+        self.model_args = None
+
+        # Start a Ray head node with enough GPUs for tensor parallelism.
+        num_gpus = self._config["model_metadata"]["tensor_parallel"]
+        logger.info(f"num GPUs ray: {num_gpus}")
+        command = f"ray start --head --num-gpus={num_gpus}"
+        subprocess.check_output(command, shell=True, text=True)
+
+    def load(self):
+        model_metadata = self._config["model_metadata"]
+        logger.info(f"main model: {model_metadata['main_model']}")
+        logger.info(f"assistant model: {model_metadata['assistant_model']}")
+        logger.info(f"tensor parallelism: {model_metadata['tensor_parallel']}")
+        logger.info(f"max num seqs: {model_metadata['max_num_seqs']}")
+
+        # The assistant (draft) model enables speculative decoding alongside the main model.
+        self.model_args = AsyncEngineArgs(
+            model=model_metadata["main_model"],
+            speculative_model=model_metadata["assistant_model"],
+            trust_remote_code=True,
+            tensor_parallel_size=model_metadata["tensor_parallel"],
+            max_num_seqs=model_metadata["max_num_seqs"],
+            dtype="half",
+            use_v2_block_manager=True,
+            enforce_eager=True,
+        )
+        self.llm_engine = AsyncLLMEngine.from_engine_args(self.model_args)
+
+    async def predict(self, model_input):
+        prompt = model_input.pop("prompt")
+        stream = model_input.pop("stream", True)
+
+        sampling_params = SamplingParams(**model_input)
+        idx = uuid.uuid4().hex  # unique request id for the engine
+        vllm_generator = self.llm_engine.generate(prompt, sampling_params, idx)
+
+        # vLLM yields the full text generated so far; emit only the new delta each step.
+        async def generator():
+            full_text = ""
+            async for output in vllm_generator:
+                text = output.outputs[0].text
+                delta = text[len(full_text) :]
+                full_text = text
+                yield delta
+
+        if stream:
+            return generator()
+        else:
+            full_text = ""
+            async for delta in generator():
+                full_text += delta
+            return {"text": full_text}
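For reference, the predict flow added above can be exercised with a minimal local harness like the sketch below. This is not part of the change: it assumes the GPU, Ray, and vLLM dependencies from config.yaml are installed, and the config dict simply mirrors the model_metadata block.

import asyncio

from model.model import Model

# Mirrors the model_metadata section of config.yaml.
config = {
    "model_metadata": {
        "main_model": "meta-llama/Meta-Llama-3-8B-Instruct",
        "assistant_model": "ibm-fms/llama3-8b-accelerator",
        "tensor_parallel": 1,
        "max_num_seqs": 16,
    }
}

async def main():
    model = Model(config=config)  # starts the Ray head node
    model.load()                  # builds the speculative-decoding engine
    # With stream=True, predict returns an async generator of text deltas.
    deltas = await model.predict(
        {"prompt": "What is speculative decoding?", "max_tokens": 128, "stream": True}
    )
    async for delta in deltas:
        print(delta, end="", flush=True)

asyncio.run(main())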