Commit cd732f5

tuhins and Tuhin Srivastava authored

mixtral-8x7b-instruct-vllm-a100-t-tp2 (#243)

Mixtral 8x7B — VLLM TP2 — A100:2

Co-authored-by: Tuhin Srivastava <[email protected]>

1 parent 1743dce commit cd732f5

File tree: 4 files changed, +124 -0 lines changed

mistral/mixtral-8x7b-instruct-vllm-a100-t-tp2/README.md
Lines changed: 62 additions & 0 deletions

# Mixtral 8x7B Instruct Truss

This is a [Truss](https://truss.baseten.co/) for Mixtral 8x7B Instruct, a mixture-of-experts (MoE) language model released by [Mistral AI](https://mistral.ai/). This README will walk you through how to deploy this Truss on Baseten to get your own instance of it.
## Deployment

First, clone this repository:

```sh
git clone https://github.com/basetenlabs/truss-examples/
cd truss-examples/mistral/mixtral-8x7b-instruct-vllm-a100-t-tp2
```

Before deployment:

1. Make sure you have a [Baseten account](https://app.baseten.co/signup) and [API key](https://app.baseten.co/settings/account/api_keys).
2. Install the latest version of Truss: `pip install --upgrade truss`

With `truss-examples/mistral/mixtral-8x7b-instruct-vllm-a100-t-tp2` as your working directory, you can deploy the model with:

```sh
truss push --publish
```

Paste your Baseten API key if prompted.

For more information, see [Truss documentation](https://truss.baseten.co).

### Hardware notes

You need two A100s to run Mixtral at `fp16`. If you need access to A100s, please [contact us](mailto:[email protected]).
## Mixtral 8x7B Instruct API documentation

This section provides an overview of the Mixtral 8x7B Instruct API, its parameters, and how to use it. The API consists of a single route named `predict`, which you can invoke to generate text based on the provided prompt.

### API route: `predict`

The `predict` route is the primary method for generating text completions based on a given prompt. It takes several parameters:

- __prompt__: The input text that you want the model to generate a response for.
- __stream__ (optional, default=`True`): A boolean determining whether the model should stream a response back. When `True`, the API returns generated text as it becomes available.
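
Any other top-level keys in the request body are passed through to vLLM's `SamplingParams` (see `model/model.py` below), so you can tune sampling directly from the client. For illustration, `temperature` and `max_tokens` are two standard vLLM sampling parameters you could pass this way:

```sh
truss predict -d '{"prompt": "What is the Mistral wind?", "temperature": 0.7, "max_tokens": 512}'
```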

## Example usage

```sh
truss predict -d '{"prompt": "What is the Mistral wind?"}'
```

You can also invoke your model via a REST API:

```sh
curl -X POST "https://app.baseten.co/model_versions/YOUR_MODEL_VERSION_ID/predict" \
  -H "Content-Type: application/json" \
  -H 'Authorization: Api-Key {YOUR_API_KEY}' \
  -d '{
    "prompt": "What is the meaning of life? Answer in substantial detail with multiple examples from famous philosophies, religions, and schools of thought.",
    "stream": true,
    "max_tokens": 4096
  }' --no-buffer
```
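
To consume the streamed response from Python, here is a minimal client sketch using the `requests` library; it assumes the same endpoint and `Api-Key` header shown in the curl call above, with placeholder values you would replace:

```python
# Minimal streaming client sketch; the endpoint and auth mirror the curl
# example above. YOUR_MODEL_VERSION_ID and YOUR_API_KEY are placeholders.
import requests

resp = requests.post(
    "https://app.baseten.co/model_versions/YOUR_MODEL_VERSION_ID/predict",
    headers={"Authorization": "Api-Key YOUR_API_KEY"},
    json={"prompt": "What is the Mistral wind?", "stream": True, "max_tokens": 512},
    stream=True,  # keep the connection open and read chunks as they arrive
)
for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
    print(chunk, end="", flush=True)
```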

mistral/mixtral-8x7b-instruct-vllm-a100-t-tp2/config.yaml
Lines changed: 13 additions & 0 deletions

```yaml
environment_variables: {}
external_package_dirs: []
model_name: Mixtral 8x7B — VLLM TP2 — A100:2
python_version: py310
requirements:
  - vllm
resources:
  accelerator: A100:2
  use_gpu: true
runtime:
  predict_concurrency: 128
secrets: {}
system_packages: []
```

mistral/mixtral-8x7b-instruct-vllm-a100-t-tp2/model/__init__.py

Whitespace-only changes.

mistral/mixtral-8x7b-instruct-vllm-a100-t-tp2/model/model.py
Lines changed: 49 additions & 0 deletions
1+
import subprocess
2+
import uuid
3+
4+
from vllm import SamplingParams
5+
from vllm.engine.arg_utils import AsyncEngineArgs
6+
from vllm.engine.async_llm_engine import AsyncLLMEngine
7+
8+
9+
class Model:
10+
def __init__(self, **kwargs):
11+
self.model = None
12+
self.llm_engine = None
13+
self.model_args = None
14+
15+
command = "ray start --head"
16+
subprocess.check_output(command, shell=True, text=True)
17+
18+
def load(self):
19+
self.model_args = AsyncEngineArgs(
20+
model="mistralai/Mixtral-8x7B-Instruct-v0.1",
21+
tensor_parallel_size=2,
22+
gpu_memory_utilization=0.95,
23+
max_model_len=4096,
24+
)
25+
self.llm_engine = AsyncLLMEngine.from_engine_args(self.model_args)
26+
27+
async def predict(self, model_input):
28+
prompt = model_input.pop("prompt")
29+
stream = model_input.pop("stream", True)
30+
31+
sampling_params = SamplingParams(**model_input)
32+
idx = str(uuid.uuid4().hex)
33+
vllm_generator = self.llm_engine.generate(prompt, sampling_params, idx)
34+
35+
async def generator():
36+
full_text = ""
37+
async for output in vllm_generator:
38+
text = output.outputs[0].text
39+
delta = text[len(full_text) :]
40+
full_text = text
41+
yield delta
42+
43+
if stream:
44+
return generator()
45+
else:
46+
full_text = ""
47+
async for delta in generator():
48+
full_text += delta
49+
return {"text": full_text}
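
As a quick illustration of how this class is driven, here is a hypothetical local smoke test (not part of the commit; it assumes two A100s and the vLLM dependencies are available, since Baseten's runtime normally constructs the model and calls `load()`/`predict()` itself):

```python
# Hypothetical local driver, for illustration only.
import asyncio


async def main():
    model = Model()
    model.load()
    # predict() returns an async generator when stream is True (the default).
    deltas = await model.predict({"prompt": "What is the Mistral wind?", "max_tokens": 64})
    async for delta in deltas:
        print(delta, end="", flush=True)


asyncio.run(main())
```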
