From c438c2dc28ddbe5b7aa6ff93fe47d7bd3e29c295 Mon Sep 17 00:00:00 2001
From: Het Trivedi
Date: Mon, 1 Jul 2024 08:05:19 -0700
Subject: [PATCH 1/2] Adding vllm spec dec example

---
 vllm-speculative-decoding/config.yaml       | 23 ++++
 vllm-speculative-decoding/model/__init__.py |  0
 vllm-speculative-decoding/model/model.py    | 65 +++++++++++++++++++++
 3 files changed, 88 insertions(+)
 create mode 100644 vllm-speculative-decoding/config.yaml
 create mode 100644 vllm-speculative-decoding/model/__init__.py
 create mode 100644 vllm-speculative-decoding/model/model.py

diff --git a/vllm-speculative-decoding/config.yaml b/vllm-speculative-decoding/config.yaml
new file mode 100644
index 000000000..b93ea7dce
--- /dev/null
+++ b/vllm-speculative-decoding/config.yaml
@@ -0,0 +1,23 @@
+base_image:
+  image: nvcr.io/nvidia/pytorch:23.11-py3
+  python_executable_path: /usr/bin/python3
+build_commands: []
+environment_variables:
+  HF_TOKEN: ""
+external_package_dirs: []
+model_metadata:
+  main_model: meta-llama/Meta-Llama-3-8B-Instruct
+  assistant_model: ibm-fms/llama3-8b-accelerator
+  tensor_parallel: 1
+  max_num_seqs: 16
+model_name: vLLM Speculative Decoding
+python_version: py310
+requirements:
+  - git+https://github.com/vllm-project/vllm@9def10664e8b54dcc5c6114f2895bc9e712bf182
+resources:
+  accelerator: A100
+  use_gpu: true
+system_packages:
+  - python3.10-venv
+runtime:
+  predict_concurrency: 128
diff --git a/vllm-speculative-decoding/model/__init__.py b/vllm-speculative-decoding/model/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/vllm-speculative-decoding/model/model.py b/vllm-speculative-decoding/model/model.py
new file mode 100644
index 000000000..e684a3af9
--- /dev/null
+++ b/vllm-speculative-decoding/model/model.py
@@ -0,0 +1,65 @@
+import logging
+import subprocess
+import uuid
+
+from vllm import SamplingParams
+from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.engine.async_llm_engine import AsyncLLMEngine
+
+logger = logging.getLogger(__name__)
+
+
+class Model:
+    def __init__(self, **kwargs):
+        self._config = kwargs["config"]
+        self.model = None
+        self.llm_engine = None
+        self.model_args = None
+
+        num_gpus = self._config["model_metadata"]["tensor_parallel"]
+        logger.info(f"num GPUs ray: {num_gpus}")
+        command = f"ray start --head --num-gpus={num_gpus}"
+        subprocess.check_output(command, shell=True, text=True)
+
+    def load(self):
+        model_metadata = self._config["model_metadata"]
+        logger.info(f"main model: {model_metadata['main_model']}")
+        logger.info(f"assistant model: {model_metadata['assistant_model']}")
+        logger.info(f"tensor parallelism: {model_metadata['tensor_parallel']}")
+        logger.info(f"max num seqs: {model_metadata['max_num_seqs']}")
+
+        self.model_args = AsyncEngineArgs(
+            model=model_metadata["main_model"],
+            speculative_model=model_metadata["assistant_model"],
+            trust_remote_code=True,
+            tensor_parallel_size=model_metadata["tensor_parallel"],
+            max_num_seqs=model_metadata["max_num_seqs"],
+            dtype="half",
+            use_v2_block_manager=True,
+            enforce_eager=True,
+        )
+        self.llm_engine = AsyncLLMEngine.from_engine_args(self.model_args)
+
+    async def predict(self, model_input):
+        prompt = model_input.pop("prompt")
+        stream = model_input.pop("stream", True)
+
+        sampling_params = SamplingParams(**model_input)
+        idx = str(uuid.uuid4().hex)
+        vllm_generator = self.llm_engine.generate(prompt, sampling_params, idx)
+
+        async def generator():
+            full_text = ""
+            async for output in vllm_generator:
+                text = output.outputs[0].text
+                delta = text[len(full_text) :]
+                full_text = text
+                yield delta
+
+        if stream:
+            return generator()
+        else:
+            full_text = ""
+            async for delta in generator():
+                full_text += delta
+            return {"text": full_text}

From fd5be0cd109695d526df7dfd5f2630e2d605b473 Mon Sep 17 00:00:00 2001
From: Het Trivedi
Date: Mon, 1 Jul 2024 14:36:17 -0700
Subject: [PATCH 2/2] minor fixes

---
 vllm-speculative-decoding/config.yaml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm-speculative-decoding/config.yaml b/vllm-speculative-decoding/config.yaml
index b93ea7dce..f9de8f2b9 100644
--- a/vllm-speculative-decoding/config.yaml
+++ b/vllm-speculative-decoding/config.yaml
@@ -1,17 +1,17 @@
+# This base image is required for developer build of vLLM
 base_image:
   image: nvcr.io/nvidia/pytorch:23.11-py3
   python_executable_path: /usr/bin/python3
 build_commands: []
 environment_variables:
   HF_TOKEN: ""
 external_package_dirs: []
 model_metadata:
   main_model: meta-llama/Meta-Llama-3-8B-Instruct
   assistant_model: ibm-fms/llama3-8b-accelerator
   tensor_parallel: 1
   max_num_seqs: 16
 model_name: vLLM Speculative Decoding
-python_version: py310
 requirements:
   - git+https://github.com/vllm-project/vllm@9def10664e8b54dcc5c6114f2895bc9e712bf182
 resources:
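
A note on the streaming logic in model.py above: AsyncLLMEngine.generate yields
RequestOutput objects whose outputs[0].text is the *cumulative* text generated so
far, so generator() slices off the already-emitted prefix to produce per-chunk
deltas. A minimal self-contained sketch of that slicing, with a hypothetical
in-memory stream standing in for the engine:

import asyncio


async def cumulative_outputs():
    # Stand-in for the engine: vLLM reports the full text so far on each step.
    for text in ["Speculative", "Speculative decoding", "Speculative decoding works."]:
        yield text


async def deltas():
    full_text = ""
    async for text in cumulative_outputs():
        delta = text[len(full_text):]  # only the characters not yet emitted
        full_text = text
        yield delta


async def main():
    async for d in deltas():
        print(repr(d))  # 'Speculative', ' decoding', ' works.'


asyncio.run(main())

Yielding deltas rather than the cumulative string keeps each streamed chunk small
and lets clients reconstruct the output by simple concatenation.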
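
To exercise the endpoint once the truss is deployed, a client can POST a prompt
and consume the stream. This is a sketch, not part of the patch: the URL assumes
the Truss development server's default local address (a deployed model has its
own URL), and any keys beyond prompt and stream are forwarded to vLLM's
SamplingParams after predict pops those two:

import requests

resp = requests.post(
    "http://localhost:8080/v1/models/model:predict",  # assumed local dev address
    json={
        "prompt": "What is speculative decoding?",
        "stream": True,
        "max_tokens": 128,    # passed through to SamplingParams
        "temperature": 0.7,   # passed through to SamplingParams
    },
    stream=True,
)
resp.raise_for_status()
for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
    print(chunk, end="", flush=True)

With "stream": false in the payload, predict instead drains the generator
server-side and returns a single {"text": ...} JSON object.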