
Commit 4133898

bolasim and htrivedi99 authored
Add Dynamic batching example (#235)
Co-authored-by: Het Trivedi <[email protected]>
1 parent b496ec2 commit 4133898

11 files changed: +1171 −0 lines changed
07-high-performance-dynamic-batching/.gitignore
Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
.venv/
payload.json
.vscode
07-high-performance-dynamic-batching/README.md
Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
# Dynamic Batching in Truss

This repository contains an implementation of dynamic batching for machine learning models within the Truss framework. The core of the implementation lies in the `model/model.py` file, which introduces an `MlBatcher` class extending `AsyncBatcher`. This class collects individual prediction requests and processes them in batches, improving throughput and efficiency.
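
Reduced to its essentials, the pattern looks roughly like this — a minimal sketch mirroring the `async-batcher` API used in `model/model.py`; the toy string-uppercasing "model" is purely illustrative:

```python
import asyncio

from async_batcher.batcher import AsyncBatcher


# Toy batcher: uppercases strings, standing in for a real model call.
class EchoBatcher(AsyncBatcher[list[str], list[str]]):
    def process_batch(self, batch: list[str]) -> list[str]:
        # Runs once per collected batch; returns one result per queued item.
        return [item.upper() for item in batch]


async def main():
    batcher = EchoBatcher(
        max_batch_size=8,  # flush once 8 items are queued...
        max_queue_time=0.25,  # ...or after 250 ms, whichever comes first
    )
    # Ten concurrent callers; the batcher coalesces them into batched calls,
    # then routes each result back to the caller that submitted the item.
    results = await asyncio.gather(
        *(batcher.process(item=f"request-{i}") for i in range(10))
    )
    print(results)


asyncio.run(main())
```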

## Key Features

- **Dynamic Batching:** The `MlBatcher` class dynamically batches incoming prediction requests, allowing for more efficient use of resources and faster response times.
- **Asynchronous Processing:** Uses asynchronous programming to handle concurrent prediction requests without blocking, ensuring high throughput.
- **Easy Integration:** Deploys as a normal Truss, making integration into existing projects straightforward.

## Deployment

To deploy this as a normal Truss, ensure you have the Truss CLI installed and configured. Then, follow these steps:

1. Clone this repository to your local machine.
2. Navigate to the repository directory and push the Truss with the command `truss push --publish`.
3. Once the build completes, the published deployment is ready to serve requests.

## Configuration

The `config.yaml` file contains configuration options for the model, including the Python version, required packages, and runtime settings such as `predict_concurrency`. Adjust these settings as needed to optimize performance for your specific use case.

## Testing

The `test.py` file provides an example of how to send concurrent requests to the deployed model for testing purposes. Modify the URL and data as needed to match your deployment.
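
In the same spirit, here is a minimal sketch of such a test (this is not the repository's `test.py`; the endpoint URL, API key, and audio file below are placeholders for your own deployment):

```python
# Concurrent-request sketch (placeholder URL, key, and audio file).
import base64
import concurrent.futures

import requests

URL = "https://model-<id>.api.baseten.co/production/predict"  # placeholder
HEADERS = {"Authorization": "Api-Key <your-api-key>"}  # placeholder

# The model expects base64-encoded audio under the "audio" key (see model.py).
with open("sample.wav", "rb") as f:
    payload = {"audio": base64.b64encode(f.read()).decode("utf-8")}


def send(_):
    return requests.post(URL, headers=HEADERS, json=payload).json()


# Fire 32 requests at once; the server-side batcher should coalesce them
# into batches of up to MAX_BATCH_SIZE.
with concurrent.futures.ThreadPoolExecutor(max_workers=32) as pool:
    for result in pool.map(send, range(32)):
        print(result["duration"], result["text"][:60])
```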

## Conclusion

This implementation showcases how dynamic batching can be seamlessly integrated into the Truss framework, providing a scalable and efficient solution for handling machine learning inference at scale.
07-high-performance-dynamic-batching/config.yaml
Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
base_image:
  image: baseten/trtllm-server:r23.12_baseten_v0.9.0.dev2024022000
  python_executable_path: /usr/bin/python3
model_name: TRT Whisper - Dynamic Batching
python_version: py311
requirements:
  - async-batcher==0.2.0
  - mpi4py==3.1.5
  - pynvml==11.5.0
  - huggingface_hub==0.20.3
  - tiktoken==0.6.0
  - datasets==2.17.1
  - kaldialign==0.9
  - openai-whisper==20231117
  - soundfile==0.12.1
model_cache:
  - repo_id: baseten/trtllm-whisper-a10g-large-v2-1
system_packages:
  - python3.10-venv
  - ffmpeg
resources:
  accelerator: A10G
runtime:
  predict_concurrency: 256
external_data:
  - local_data_path: assets/multilingual.tiktoken
    url: https://raw.githubusercontent.com/openai/whisper/main/whisper/assets/multilingual.tiktoken
  - local_data_path: assets/mel_filters.npz
    url: https://raw.githubusercontent.com/openai/whisper/main/whisper/assets/mel_filters.npz
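
Worth noting: `predict_concurrency: 256` allows up to 256 `predict` calls to be in flight inside the server at once, which is what gives the `MlBatcher` in `model/model.py` a pool of waiting requests to coalesce into batches of up to `MAX_BATCH_SIZE = 8`, flushed at most every `MAX_QUEUE_TIME = 0.25` seconds.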

07-high-performance-dynamic-batching/model/__init__.py

Whitespace-only changes.
07-high-performance-dynamic-batching/model/model.py
Lines changed: 85 additions & 0 deletions
@@ -0,0 +1,85 @@
import base64
import gc
import os
import re
from tempfile import NamedTemporaryFile

import torch
from async_batcher.batcher import AsyncBatcher
from huggingface_hub import snapshot_download
from run import WhisperTRTLLM
from torch import Tensor
from whisper_utils import log_mel_spectrogram

TEXT_PREFIX = "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>"

# Num beams is the number of paths the model traverses before transcribing the text
NUM_BEAMS = 3

# Max queue time is the amount of time in seconds to wait to fill the batch
MAX_QUEUE_TIME = 0.25

# Maximum size of the batch. This is dictated by the compiled engine.
MAX_BATCH_SIZE = 8


class MlBatcher(AsyncBatcher[list[Tensor], list[str]]):
    def __init__(self, model, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.model: WhisperTRTLLM = model

    def process_batch(self, batch: list[Tensor]) -> list[str]:
        # NOTE: the compiled engine may need the batch padded up to
        # MAX_BATCH_SIZE; here the queued spectrograms are concatenated
        # into a single (batch, n_mels, frames) tensor.
        features = torch.cat(batch, dim=0).type(torch.float16)
        return self.model.process_batch(features, TEXT_PREFIX, NUM_BEAMS)


class Model:
    def __init__(self, **kwargs):
        self._data_dir = kwargs["data_dir"]
        self._model = None
        self._batcher = None
        # Move objects allocated so far into the permanent GC generation so
        # they are excluded from future collections, reducing GC pauses.
        gc.freeze()

    def load(self):
        # Download the compiled model from the Hugging Face Hub
        snapshot_download(
            "baseten/trtllm-whisper-a10g-large-v2-1",
            local_dir=self._data_dir,
            max_workers=4,
        )

        self._model = WhisperTRTLLM(f"{self._data_dir}")
        self._batcher = MlBatcher(
            model=self._model,
            max_batch_size=MAX_BATCH_SIZE,
            max_queue_time=MAX_QUEUE_TIME,
        )

    def base64_to_wav(self, base64_string, output_file_path):
        binary_data = base64.b64decode(base64_string)
        with open(output_file_path, "wb") as wav_file:
            wav_file.write(binary_data)
        return output_file_path

    async def predict(self, model_input: dict):
        # TODO: figure out what the normalizer is for
        normalizer = None
        with NamedTemporaryFile() as fp:
            self.base64_to_wav(model_input["audio"], fp.name)
            mel, total_duration = log_mel_spectrogram(
                fp.name,
                self._model.n_mels,
                device="cuda",
                return_duration=True,
                mel_filters_dir=f"{self._data_dir}/assets",
            )
            mel = mel.type(torch.float16)
            # Add a leading batch dimension so items can be concatenated
            mel = mel.unsqueeze(0)
            prediction = await self._batcher.process(item=mel)

        # Remove all special tokens from the prediction
        prediction = re.sub(r"<\|.*?\|>", "", prediction)
        if normalizer:
            prediction = normalizer(prediction)
        return {"text": prediction.strip(), "duration": total_duration}

07-high-performance-dynamic-batching/packages/__init__.py

Whitespace-only changes.
