Commit b496ec2

Add Llama 7B vLLM example (#233)
1 parent 5f41259 · commit b496ec2

3 files changed, +77 -0 lines changed


llama/llama-7b-vllm/config.yaml

Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
model_metadata:
  engine_args:
    model: TheBloke/Llama-2-7B-Chat-fp16
  example_model_input:
    prompt: Where do Llamas come from?
  pretty_name: Llama 2 7B
  prompt_format: <s>[INST] {prompt} [/INST]
  tags:
  - text-generation
model_name: Llama 7B Instruct vLLM
python_version: py311
requirements:
- vllm==0.2.1.post1
resources:
  accelerator: A10G
  memory: 25Gi
  use_gpu: true
runtime:
  predict_concurrency: 256
system_packages: []
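
Note (not part of this commit): the engine_args and prompt_format entries under model_metadata are read back by model/model.py below, while requirements pins vllm==0.2.1.post1 and resources requests an A10G GPU. As a rough, hedged sketch only, a request mirroring example_model_input could be sent to a deployed copy of this Truss along these lines; ENDPOINT and API_KEY are placeholders, and the exact URL and auth scheme depend on where the model is deployed.

import requests

ENDPOINT = "https://<your-model-endpoint>/predict"  # placeholder, depends on your deployment
API_KEY = "<your-api-key>"  # placeholder

# Mirrors example_model_input from config.yaml; stream=False asks predict()
# for a single JSON payload instead of a token stream.
response = requests.post(
    ENDPOINT,
    headers={"Authorization": f"Api-Key {API_KEY}"},
    json={"prompt": "Where do Llamas come from?", "stream": False},
)
print(response.json())  # expected shape: {"text": "..."}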

llama/llama-7b-vllm/model/__init__.py

Whitespace-only changes.

llama/llama-7b-vllm/model/model.py

Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
import uuid
from typing import Any

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine


class Model:
    def __init__(self, **kwargs) -> None:
        # Engine arguments and the prompt template come from model_metadata in config.yaml.
        self.engine_args = kwargs["config"]["model_metadata"]["engine_args"]
        self.prompt_format = kwargs["config"]["model_metadata"]["prompt_format"]

    def load(self) -> None:
        # Build the asynchronous vLLM engine once at startup.
        self.llm_engine = AsyncLLMEngine.from_engine_args(
            AsyncEngineArgs(**self.engine_args)
        )

    async def predict(self, request: dict) -> Any:
        prompt = request.pop("prompt")
        stream = request.pop("stream", True)
        formatted_prompt = self.prompt_format.replace("{prompt}", prompt)

        # Default sampling parameters; any remaining request keys override them.
        generate_args = {
            "n": 1,
            "best_of": 1,
            "max_tokens": 512,
            "temperature": 1.0,
            "top_p": 0.95,
            "top_k": 50,
            "frequency_penalty": 1.0,
            "presence_penalty": 1.0,
            "use_beam_search": False,
        }
        generate_args.update(request)

        sampling_params = SamplingParams(**generate_args)
        idx = str(uuid.uuid4().hex)  # unique request id for the engine
        vllm_generator = self.llm_engine.generate(
            formatted_prompt, sampling_params, idx
        )

        # vLLM yields the full text generated so far; emit only the new delta each step.
        async def generator():
            full_text = ""
            async for output in vllm_generator:
                text = output.outputs[0].text
                delta = text[len(full_text) :]
                full_text = text
                yield delta

        if stream:
            return generator()
        else:
            # Drain the generator and return the complete response in one payload.
            full_text = ""
            async for delta in generator():
                full_text += delta
            return {"text": full_text}
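
Note (not part of this commit): predict() returns an async generator of text deltas when stream is true, and a single {"text": ...} dict otherwise; the inner generator() turns vLLM's cumulative outputs into incremental deltas by slicing off the text already emitted. A minimal, hedged driver for exercising the class directly might look like the sketch below; it assumes a GPU machine with vllm==0.2.1.post1 installed and inlines a config dict shaped like config.yaml.

import asyncio

# Inline stand-in for the parsed config.yaml that Truss would normally inject.
config = {
    "model_metadata": {
        "engine_args": {"model": "TheBloke/Llama-2-7B-Chat-fp16"},
        "prompt_format": "<s>[INST] {prompt} [/INST]",
    }
}

async def main():
    model = Model(config=config)
    model.load()  # builds the AsyncLLMEngine (requires a GPU)

    # Streaming: iterate over the async generator of deltas.
    deltas = await model.predict({"prompt": "Where do Llamas come from?", "stream": True})
    async for delta in deltas:
        print(delta, end="", flush=True)
    print()

    # Non-streaming: a single dict with the full completion.
    result = await model.predict({"prompt": "Where do Llamas come from?", "stream": False})
    print(result["text"])

asyncio.run(main())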
