
Commit 0487fe9

Whisper Torchserve truss (#225)
1 parent 0f9ed68 commit 0487fe9


5 files changed: +181 −0 lines changed

Lines changed: 82 additions & 0 deletions
@@ -0,0 +1,82 @@

# Whisper Torchserve

This Truss lets you run a Whisper model with [TorchServe](https://pytorch.org/serve/) as the serving backend.

## Deployment

Before deployment:

1. Make sure you have a [Baseten account](https://app.baseten.co/signup) and [API key](https://app.baseten.co/settings/account/api_keys).
2. Install the latest version of Truss: `pip install --upgrade truss`

With `whisper/whisper-torchserve` as your working directory, you can deploy the model with:

```
truss push
```

Paste your Baseten API key if prompted.

For more information, see [Truss documentation](https://truss.baseten.co).

## Model Inputs

The model takes in one input:

- __audio__: An audio file as a base64 string (see the example request below)
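
For example, a request body looks like this; the base64 string here is just the truncated prefix of a WAV file, for illustration:

```json
{"audio": "UklGRiQAAABXQVZF..."}
```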

## A few things to note

TorchServe requires a compiled `.mar` file in order to serve the model. This [README](https://github.com/pytorch/serve/blob/master/model-archiver/README.md) gives a brief explanation of how to generate that file; a rough sketch of the archiver invocation is also shown at the end of this section. Once the `.mar` file is generated, it needs to be placed in the `data/model_store` directory. The `data/` directory also contains a configuration file for TorchServe called `config.properties`. That file looks something like this:

```
inference_address=http://0.0.0.0:8888
batch_size=4
ipex_enable=true
async_logging=true

models={\
  "whisper_base": {\
    "1.0": {\
      "defaultVersion": true,\
      "marName": "whisper_base.mar",\
      "minWorkers": 1,\
      "maxWorkers": 2,\
      "batchSize": 4,\
      "maxBatchDelay": 500,\
      "responseTimeout": 24\
    }\
  }\
}
```

Here you can specify the `batchSize` as well as the name of your `.mar` file using `marName`. When TorchServe starts, it looks for the `.mar` file inside the `data/model_store` directory with the `marName` defined above. TorchServe batches up to `batchSize` concurrent requests, waiting at most `maxBatchDelay` milliseconds before dispatching a partial batch.
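
For reference, generating the `.mar` file with PyTorch's `torch-model-archiver` CLI looks roughly like this. This is a sketch, not the exact command used to build this model's archive: the handler file name (`handler.py`) is a placeholder for whatever custom handler wraps the Whisper weights.

```
torch-model-archiver \
  --model-name whisper_base \
  --version 1.0 \
  --handler handler.py \
  --export-path data/model_store
```

The `--model-name` flag determines the name of the generated archive, which must match the `marName` set in `config.properties`.
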
## Invoking the model

Here is an example in Python:

```python
import base64

import requests


def wav_to_base64(file_path):
    # Read the WAV file and encode its raw bytes as a base64 string
    with open(file_path, "rb") as wav_file:
        binary_data = wav_file.read()
        base64_data = base64.b64encode(binary_data)
        base64_string = base64_data.decode("utf-8")
        return base64_string


resp = requests.post(
    "https://model-<model-id>.api.baseten.co/development/predict",
    headers={"Authorization": "Api-Key BASETEN-API-KEY"},
    json={"audio": wav_to_base64("/path/to/audio-file/60-sec.wav")},
)

print(resp.json())
```

Here is a sample output:

```json
{"output": "Let me make it clear. His conduct is unacceptable. He's unfit. And be careful of what you're gonna get. He doesn't care for the American people. It's Donald Trump first. This is what I want people to understand. These people have... I mean, she has no idea what the hell the names of those provinces are, but she wants to send our sons and daughters and our troops and our military equipment to go fight it. Look at the blank expression. She doesn't know the names of the provinces. You do this at every debate. You say, no, don't interrupt me. I didn't interrupt you."}
```

Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
environment_variables: {}
external_package_dirs: []
model_metadata: {}
model_name: Whisper Torchserve
python_version: py310
requirements:
- torch==2.1.0
- torchserve==0.9.0
- ffmpeg-python==0.2.0
- transformers==4.37.2
- nvgpu==0.10.0
- httpx==0.27.0
resources:
  accelerator: T4
  use_gpu: true
model_cache:
- repo_id: htrivedi99/whisper-torchserve
secrets: {}
system_packages:
- ffmpeg
- openjdk-11-jdk
runtime:
  predict_concurrency: 128
Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
inference_address=http://0.0.0.0:8888
batch_size=4
ipex_enable=true
async_logging=true

models={\
  "whisper_base": {\
    "1.0": {\
      "defaultVersion": true,\
      "marName": "whisper_base.mar",\
      "minWorkers": 1,\
      "maxWorkers": 2,\
      "batchSize": 4,\
      "maxBatchDelay": 500,\
      "responseTimeout": 24\
    }\
  }\
}

# default_workers_per_model=2

whisper/whisper-torchserve/model/__init__.py

Whitespace-only changes.
Lines changed: 56 additions & 0 deletions
@@ -0,0 +1,56 @@
import base64
import multiprocessing
import os
import subprocess
from typing import Dict

import httpx
from huggingface_hub import snapshot_download

TORCHSERVE_ENDPOINT = "http://0.0.0.0:8888/predictions/whisper_base"


class Model:
    def __init__(self, **kwargs):
        self._data_dir = kwargs["data_dir"]
        self._model = None

    def start_torchserve(self):
        # Run TorchServe in the foreground; this call blocks, so load()
        # launches it in a separate process.
        subprocess.run(
            [
                "torchserve",
                "--start",
                "--model-store",
                f"{self._data_dir}/model_store",
                "--models",
                "whisper_base.mar",
                "--foreground",
                "--no-config-snapshots",
                "--ts-config",
                f"{self._data_dir}/config.properties",
            ],
            check=True,
        )

    def load(self):
        # Download the pre-compiled .mar file into the model store directory
        snapshot_download(
            "htrivedi99/whisper-torchserve",
            local_dir=os.path.join(self._data_dir, "model_store"),
            max_workers=4,
        )
        print("Downloaded weights successfully!")

        process = multiprocessing.Process(target=self.start_torchserve)
        process.start()

    async def predict(self, request: Dict):
        audio_base64 = request.get("audio")
        audio_bytes = base64.b64decode(audio_base64)

        # Forward the decoded audio bytes to the TorchServe endpoint
        # as multipart form data and return the transcription text.
        async with httpx.AsyncClient() as client:
            res = await client.post(
                TORCHSERVE_ENDPOINT, files={"data": (None, audio_bytes)}
            )
            transcription = res.text
        return {"output": transcription}
