
Commit 9a9c09f

Merge pull request #48 from njeffrie/master
Implement batching for Useful Sensors Moonshine
2 parents 66a61a9 + 115205b commit 9a9c09f

File tree: 3 files changed (+33, −15 lines)

moonshine/run_eval.py

Lines changed: 28 additions & 12 deletions
@@ -1,7 +1,8 @@
 import argparse
 import os
 import torch
-from transformers import AutoConfig, AutoModelForSpeechSeq2Seq, AutoProcessor, PreTrainedTokenizerFast
+from transformers import MoonshineForConditionalGeneration, AutoProcessor
+
 import evaluate
 from normalizer import data_utils
 import time
@@ -12,9 +13,9 @@
 torch.set_float32_matmul_precision('high')
 
 def main(args):
-    config = AutoConfig.from_pretrained(args.model_id, trust_remote_code=True)
-    model = AutoModelForSpeechSeq2Seq.from_pretrained(args.model_id, torch_dtype=torch.bfloat16, trust_remote_code=True).to(args.device)
-    tokenizer = PreTrainedTokenizerFast.from_pretrained(args.model_id, trust_remote_code=True)
+    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+    model = MoonshineForConditionalGeneration.from_pretrained(args.model_id).to(args.device).to(torch_dtype)
+    processor = AutoProcessor.from_pretrained(args.model_id)
 
     if args.torch_compile:
         model.forward = torch.compile(model.forward, mode=args.compile_mode, fullgraph=True)
@@ -30,16 +31,30 @@ def benchmark(batch, min_new_tokens=None):
         # START TIMING
         start_time = time.time()
 
-        np_arr = np.array(audios)
-        input_tensor = torch.FloatTensor(np_arr)
-        moonshine_min_input_size = 1024
-        padding = moonshine_min_input_size - input_tensor.size()[1]
-        if padding > 0:
-            input_tensor = torch.nn.functional.pad(input_tensor, (0, padding))
-        pred_ids = model(input_tensor.to(args.device).to(torch.bfloat16))
+        # 1. Pre-Processing
+        # 1.1 Pad audios to max batch size if using torch compile to prevent re-compilations
+        padding_size = 0
+        if minibatch_size != args.batch_size and args.torch_compile:
+            padding_size = args.batch_size - minibatch_size
+            padding_audios = [audios[-1] for _ in range(padding_size)]
+            audios.extend(padding_audios)
+
+        inputs = processor(audios, return_tensors="pt", padding=True, sampling_rate=16000).to(args.device).to(torch_dtype)
+
+        # Create a mask for output tokens to limit length based on input audio clip length.
+        # Add 2 to token limits to account for <sot> and <eot>.
+        token_generation_limits = [len(clip) * 6.5 // 16000 + 2 for clip in audios]
+        max_new_tokens = torch.tensor(token_generation_limits).reshape((-1, 1)).to(args.device)
+
+        pred_ids = model.generate(**inputs, max_new_tokens=max_new_tokens.max())
+        output_mask = torch.arange(pred_ids.shape[-1]).repeat((pred_ids.shape[0], 1)).to(args.device)
+        output_mask = output_mask > max_new_tokens
+
+        eot_token = model.config.eos_token_id
+        pred_ids.masked_fill(output_mask, eot_token)
 
         # 3.2 Convert token ids to text transcription
-        pred_text = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
+        pred_text = processor.batch_decode(pred_ids, skip_special_tokens=True)
 
         # END TIMING
         runtime = time.time() - start_time
@@ -48,6 +63,7 @@ def benchmark(batch, min_new_tokens=None):
         batch["transcription_time_s"] = minibatch_size * [runtime / minibatch_size]
 
         # normalize transcriptions with English normalizer
+        pred_text = pred_text if padding_size == 0 else pred_text[:-padding_size]
         batch["predictions"] = [data_utils.normalizer(pred) for pred in pred_text]
        batch["references"] = batch["norm_text"]
         return batch
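
For reference, the new batched inference path can be exercised in isolation roughly as follows. This is a minimal sketch, not code from the repo: it assumes a transformers build with Moonshine support, feeds dummy 16 kHz audio, and reuses the diff's ~6.5 tokens-per-second budget and end-of-text masking trick. Unlike the committed code, it reassigns the result of masked_fill, since that call is not in-place.

import numpy as np
import torch
from transformers import AutoProcessor, MoonshineForConditionalGeneration

model_id = "usefulsensors/moonshine-tiny"
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if torch.cuda.is_available() else torch.float32

processor = AutoProcessor.from_pretrained(model_id)
model = MoonshineForConditionalGeneration.from_pretrained(model_id).to(device).to(dtype)

# Two silent 16 kHz clips of different lengths; the processor zero-pads them
# into one batch and returns an attention mask.
audios = [np.zeros(16000, dtype=np.float32), np.zeros(24000, dtype=np.float32)]
inputs = processor(audios, return_tensors="pt", padding=True, sampling_rate=16000).to(device).to(dtype)

# Per-clip token budget: roughly 6.5 tokens per second of audio, plus 2 for
# the <sot>/<eot> specials (same heuristic as the diff above).
limits = torch.tensor([len(a) * 6.5 // 16000 + 2 for a in audios], device=device).reshape(-1, 1)

# Generate up to the largest budget, then overwrite every position past a
# clip's own budget with the end-of-text token, so shorter clips do not keep
# tokens generated only to fill out the longer clips in the same batch.
pred_ids = model.generate(**inputs, max_new_tokens=int(limits.max().item()))
positions = torch.arange(pred_ids.shape[-1], device=device).repeat(pred_ids.shape[0], 1)
pred_ids = pred_ids.masked_fill(positions > limits, model.config.eos_token_id)

print(processor.batch_decode(pred_ids, skip_special_tokens=True))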

moonshine/run_moonshine.sh

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 export PYTHONPATH="..":$PYTHONPATH
 
 MODEL_IDs=("usefulsensors/moonshine-base" "usefulsensors/moonshine-tiny")
-BATCH_SIZE=1
+BATCH_SIZE=64
 
 num_models=${#MODEL_IDs[@]}
 
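
The jump from BATCH_SIZE=1 to BATCH_SIZE=64 is what makes the fixed-shape padding in run_eval.py matter: with torch.compile enabled, a final partial minibatch would otherwise change the batch dimension and trigger a recompilation. A rough, self-contained sketch of that bookkeeping (names and dummy data are illustrative, not taken from the repo):

import numpy as np

batch_size = 64                                     # BATCH_SIZE from run_moonshine.sh
audios = [np.zeros(16000, dtype=np.float32)] * 10   # a short final minibatch
minibatch_size = len(audios)

# Pad the minibatch back up to the fixed batch size by repeating the last
# clip, so a compiled model only ever sees one batch shape.
padding_size = 0
if minibatch_size != batch_size:
    padding_size = batch_size - minibatch_size
    audios = audios + [audios[-1]] * padding_size

# ... the batched generate/decode path shown above would produce pred_text ...
pred_text = [f"transcript {i}" for i in range(len(audios))]  # stand-in output

# Drop the filler transcriptions again so only real clips are scored.
pred_text = pred_text if padding_size == 0 else pred_text[:-padding_size]
assert len(pred_text) == minibatch_size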

Lines changed: 4 additions & 2 deletions
@@ -1,7 +1,9 @@
 torch
-transformers
 evaluate
-datasets
 librosa
 jiwer
 einops
+datasets==3.2.0
+numba==0.60.0
+numpy==2.0.2
+git+https://github.com/huggingface/transformers.git#egg=transformers
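
The transformers entry moves from a PyPI pin to a git install, presumably because Moonshine support had not yet reached a tagged release when this was merged. A quick way to confirm that the installed build exposes the classes run_eval.py now imports (plain sanity check, nothing repo-specific):

# Raises ImportError on transformers builds that predate Moonshine support.
import transformers
from transformers import AutoProcessor, MoonshineForConditionalGeneration

print("transformers", transformers.__version__)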
