import base64
import gc
import os
import re
from tempfile import NamedTemporaryFile

import torch
from async_batcher.batcher import AsyncBatcher
from huggingface_hub import snapshot_download
from run import WhisperTRTLLM
from torch import Tensor
from whisper_utils import log_mel_spectrogram

TEXT_PREFIX = "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>"

# Number of beams: how many candidate transcriptions beam search keeps while decoding
NUM_BEAMS = 3

# Maximum time, in seconds, to wait for the batch to fill before running inference
MAX_QUEUE_TIME = 0.25

# Maximum size of the batch. This is dictated by the compiled engine.
MAX_BATCH_SIZE = 8


class MlBatcher(AsyncBatcher[list[Tensor], list[str]]):
    def __init__(self, model, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.model: WhisperTRTLLM = model

    def process_batch(self, batch: list[Tensor]) -> list[str]:
        # Combine the queued mel features into a single half-precision batch for the engine
        features = torch.cat(batch, dim=0).type(torch.float16)
        return self.model.process_batch(features, TEXT_PREFIX, NUM_BEAMS)


class Model:
    def __init__(self, **kwargs):
        self._data_dir = kwargs["data_dir"]
        self._model = None
        self._batcher = None
        # Freeze everything the GC currently tracks (move it to the permanent
        # generation) so later collections skip these objects during inference
        gc.freeze()

    def load(self):
        # Download the compiled model from the Hugging Face Hub
        snapshot_download(
            "baseten/trtllm-whisper-a10g-large-v2-1",
            local_dir=self._data_dir,
            max_workers=4,
        )

        self._model = WhisperTRTLLM(f"{self._data_dir}")
        self._batcher = MlBatcher(
            model=self._model,
            max_batch_size=MAX_BATCH_SIZE,
            max_queue_time=MAX_QUEUE_TIME,
        )

    def base64_to_wav(self, base64_string, output_file_path):
        binary_data = base64.b64decode(base64_string)
        with open(output_file_path, "wb") as wav_file:
            wav_file.write(binary_data)
        return output_file_path

    async def predict(self, model_input: dict):
        # TODO: figure out what the normalizer is for
        normalizer = None
        with NamedTemporaryFile() as fp:
            # Decode the base64 payload into a temporary wav file
            self.base64_to_wav(model_input["audio"], fp.name)
            mel, total_duration = log_mel_spectrogram(
                fp.name,
                self._model.n_mels,
                device="cuda",
                return_duration=True,
                mel_filters_dir=f"{self._data_dir}/assets",
            )
            mel = mel.type(torch.float16)
            # Add a leading batch dimension before handing the features to the batcher
            mel = mel.unsqueeze(0)
            prediction = await self._batcher.process(item=mel)

        # Remove all special tokens (e.g. <|notimestamps|>) from the prediction
        prediction = re.sub(r"<\|.*?\|>", "", prediction)
        if normalizer:
            prediction = normalizer(prediction)
        return {"text": prediction.strip(), "duration": total_duration}
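
For reference, predict() expects a JSON body with a single "audio" field holding base64-encoded WAV bytes and returns the transcription plus the audio duration. A minimal client-side sketch follows; the file name, endpoint URL, and use of the requests library are illustrative placeholders, not part of this model:

import base64

import requests  # any HTTP client works

# Build the payload predict() expects ("sample.wav" and the URL below are placeholders)
with open("sample.wav", "rb") as f:
    payload = {"audio": base64.b64encode(f.read()).decode("utf-8")}

resp = requests.post("https://<your-model-endpoint>/predict", json=payload)
print(resp.json())  # e.g. {"text": "...", "duration": ...}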
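
On batching behavior: each call to predict() awaits self._batcher.process(...), and the batcher holds concurrent requests for up to MAX_QUEUE_TIME seconds, or until MAX_BATCH_SIZE items arrive, before calling process_batch once for the whole group. Below is a rough standalone sketch of that pattern, using only the constructor arguments and methods exercised above; the async-batcher library's start-up/shutdown details are glossed over, so treat it as illustrative rather than exact:

import asyncio

from async_batcher.batcher import AsyncBatcher


class DoubleBatcher(AsyncBatcher[list[int], list[int]]):
    def process_batch(self, batch: list[int]) -> list[int]:
        # Invoked once per collected batch, not once per item
        print(f"processing a batch of {len(batch)} items")
        return [x * 2 for x in batch]


async def main():
    batcher = DoubleBatcher(max_batch_size=8, max_queue_time=0.25)
    # Four concurrent callers submitted together should share one process_batch call
    results = await asyncio.gather(*(batcher.process(item=i) for i in range(4)))
    print(results)  # expected: [0, 2, 4, 6]


asyncio.run(main())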