-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Expand file tree
/
Copy pathtranscribe.py
More file actions
635 lines (552 loc) · 26.1 KB
/
transcribe.py
File metadata and controls
635 lines (552 loc) · 26.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
# Copyright © 2023 Apple Inc.
import sys
import time
import warnings
from typing import List, Optional, Tuple, Union
import mlx.core as mx
import numpy as np
import tqdm
from .audio import (
FRAMES_PER_SECOND,
HOP_LENGTH,
N_FRAMES,
N_SAMPLES,
SAMPLE_RATE,
log_mel_spectrogram,
pad_or_trim,
)
from .decoding import DecodingOptions, DecodingResult
from .load_models import load_model
from .timing import add_word_timestamps
from .tokenizer import LANGUAGES, get_tokenizer
def _format_timestamp(seconds: float):
assert seconds >= 0, "non-negative timestamp expected"
milliseconds = round(seconds * 1000.0)
hours = milliseconds // 3_600_000
milliseconds -= hours * 3_600_000
minutes = milliseconds // 60_000
milliseconds -= minutes * 60_000
seconds = milliseconds // 1_000
milliseconds -= seconds * 1_000
hours_marker = f"{hours:02d}:" if hours > 0 else ""
return f"{hours_marker}{minutes:02d}:{seconds:02d}.{milliseconds:03d}"
def _get_end(segments: List[dict]) -> Optional[float]:
return next(
(w["end"] for s in reversed(segments) for w in reversed(s["words"])),
segments[-1]["end"] if segments else None,
)
class ModelHolder:
model = None
model_path = None
@classmethod
def get_model(cls, model_path: str, dtype: mx.Dtype):
if cls.model is None or model_path != cls.model_path:
cls.model = load_model(model_path, dtype=dtype)
cls.model_path = model_path
return cls.model
def transcribe(
audio: Union[str, np.ndarray, mx.array],
*,
path_or_hf_repo: str = "mlx-community/whisper-turbo",
verbose: Optional[bool] = None,
temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
compression_ratio_threshold: Optional[float] = 2.4,
logprob_threshold: Optional[float] = -1.0,
no_speech_threshold: Optional[float] = 0.6,
condition_on_previous_text: bool = True,
initial_prompt: Optional[str] = None,
word_timestamps: bool = False,
prepend_punctuations: str = "\"'“¿([{-",
append_punctuations: str = "\"'.。,,!!??::”)]}、",
clip_timestamps: Union[str, List[float]] = "0",
hallucination_silence_threshold: Optional[float] = None,
**decode_options,
):
"""
Transcribe an audio file using Whisper
Parameters
----------
audio: Union[str, np.ndarray, mx.array]
The path to the audio file to open, or the audio waveform
path_or_hf_repo: str
The localpath to the Whisper model or HF Hub repo with the MLX converted weights.
verbose: bool
Whether to display the text being decoded to the console. If True, displays all the details,
If False, displays minimal details. If None, does not display anything
temperature: Union[float, Tuple[float, ...]]
Temperature for sampling. It can be a tuple of temperatures, which will be successively used
upon failures according to either `compression_ratio_threshold` or `logprob_threshold`.
compression_ratio_threshold: float
If the gzip compression ratio is above this value, treat as failed
logprob_threshold: float
If the average log probability over sampled tokens is below this value, treat as failed
no_speech_threshold: float
If the no_speech probability is higher than this value AND the average log probability
over sampled tokens is below `logprob_threshold`, consider the segment as silent
condition_on_previous_text: bool
if True, the previous output of the model is provided as a prompt for the next window;
disabling may make the text inconsistent across windows, but the model becomes less prone to
getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
word_timestamps: bool
Extract word-level timestamps using the cross-attention pattern and dynamic time warping,
and include the timestamps for each word in each segment.
prepend_punctuations: str
If word_timestamps is True, merge these punctuation symbols with the next word
append_punctuations: str
If word_timestamps is True, merge these punctuation symbols with the previous word
initial_prompt: Optional[str]
Optional text to provide as a prompt for the first window. This can be used to provide, or
"prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
to make it more likely to predict those word correctly.
decode_options: dict
Keyword arguments to construct `DecodingOptions` instances
clip_timestamps: Union[str, List[float]]
Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process.
The last end timestamp defaults to the end of the file.
hallucination_silence_threshold: Optional[float]
When word_timestamps is True, skip silent periods longer than this threshold (in seconds)
when a possible hallucination is detected
Returns
-------
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
the spoken language ("language"), which is detected when `decode_options["language"]` is None.
"""
dtype = mx.float16 if decode_options.get("fp16", True) else mx.float32
# Track model loading time
model_load_start = time.time()
model = ModelHolder.get_model(path_or_hf_repo, dtype)
model_load_time = time.time() - model_load_start
# Track mel spectrogram computation time
mel_start = time.time()
# Pad 30-seconds of silence to the input audio, for slicing
mel = log_mel_spectrogram(audio, n_mels=model.dims.n_mels, padding=N_SAMPLES)
mel_time = time.time() - mel_start
content_frames = mel.shape[-2] - N_FRAMES
content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE)
if verbose:
system_encoding = sys.getdefaultencoding()
if system_encoding != "utf-8":
make_safe = lambda x: x.encode(system_encoding, errors="replace").decode(
system_encoding
)
else:
make_safe = lambda x: x
if decode_options.get("language", None) is None:
if not model.is_multilingual:
decode_options["language"] = "en"
else:
if verbose:
print(
"Detecting language using up to the first 30 seconds. "
"Use the `language` decoding option to specify the language"
)
mel_segment = pad_or_trim(mel, N_FRAMES, axis=-2).astype(dtype)
_, probs = model.detect_language(mel_segment)
decode_options["language"] = max(probs, key=probs.get)
if verbose is not None:
print(
f"Detected language: {LANGUAGES[decode_options['language']].title()}"
)
language: str = decode_options["language"]
task: str = decode_options.get("task", "transcribe")
tokenizer = get_tokenizer(
model.is_multilingual,
num_languages=model.num_languages,
language=language,
task=task,
)
if isinstance(clip_timestamps, str):
clip_timestamps = [
float(ts) for ts in (clip_timestamps.split(",") if clip_timestamps else [])
]
seek_points: List[int] = [round(ts * FRAMES_PER_SECOND) for ts in clip_timestamps]
if len(seek_points) == 0:
seek_points.append(0)
if len(seek_points) % 2 == 1:
seek_points.append(content_frames)
else:
seek_points[-1] = min(content_frames, seek_points[-1])
seek_clips: List[Tuple[int, int]] = list(zip(seek_points[::2], seek_points[1::2]))
punctuation = "\"'“¿([{-\"'.。,,!!??::”)]}、"
if word_timestamps and task == "translate":
warnings.warn("Word-level timestamps on translations may not be reliable.")
# Initialize timing metrics
total_inference_time = 0.0
total_tokens_generated = 0
total_inference_steps = (
0 # Total decoder forward passes (comparable to whisper.cpp "runs")
)
segment_timings = []
def decode_with_fallback(segment: mx.array) -> DecodingResult:
nonlocal total_inference_time, total_tokens_generated, total_inference_steps
temperatures = (
[temperature] if isinstance(temperature, (int, float)) else temperature
)
decode_result = None
segment_start_time = time.time()
for t in temperatures:
kwargs = {**decode_options}
if t > 0:
# disable beam_size and patience when t > 0
kwargs.pop("beam_size", None)
kwargs.pop("patience", None)
else:
# disable best_of when t == 0
kwargs.pop("best_of", None)
options = DecodingOptions(**kwargs, temperature=t)
decode_result = model.decode(segment, options)
needs_fallback = False
if (
compression_ratio_threshold is not None
and decode_result.compression_ratio > compression_ratio_threshold
):
needs_fallback = True # too repetitive
if (
logprob_threshold is not None
and decode_result.avg_logprob < logprob_threshold
):
needs_fallback = True # average log probability is too low
if (
no_speech_threshold is not None
and decode_result.no_speech_prob > no_speech_threshold
):
needs_fallback = False # silence
if not needs_fallback:
break
segment_inference_time = time.time() - segment_start_time
total_inference_time += segment_inference_time
num_output_tokens = len(decode_result.tokens)
num_inference_steps = decode_result.num_inference_steps
total_tokens_generated += num_output_tokens
total_inference_steps += num_inference_steps
segment_timings.append(
{
"time": segment_inference_time,
"output_tokens": num_output_tokens,
"inference_steps": num_inference_steps,
"output_tokens_per_sec": (
num_output_tokens / segment_inference_time
if segment_inference_time > 0
else 0
),
"inference_steps_per_sec": (
num_inference_steps / segment_inference_time
if segment_inference_time > 0
else 0
),
}
)
return decode_result
clip_idx = 0
seek = seek_clips[clip_idx][0]
input_stride = N_FRAMES // model.dims.n_audio_ctx # mel frames per output token: 2
time_precision = (
input_stride * HOP_LENGTH / SAMPLE_RATE
) # time per output token: 0.02 (seconds)
all_tokens = []
all_segments = []
prompt_reset_since = 0
if initial_prompt is not None:
initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
all_tokens.extend(initial_prompt_tokens)
else:
initial_prompt_tokens = []
def new_segment(
*, start: float, end: float, tokens: mx.array, result: DecodingResult
):
tokens = tokens.tolist()
text_tokens = [token for token in tokens if token < tokenizer.eot]
return {
"seek": seek,
"start": start,
"end": end,
"text": tokenizer.decode(text_tokens),
"tokens": tokens,
"temperature": result.temperature,
"avg_logprob": result.avg_logprob,
"compression_ratio": result.compression_ratio,
"no_speech_prob": result.no_speech_prob,
}
# show the progress bar when verbose is False (if True, transcribed text will be printed)
with tqdm.tqdm(
total=content_frames, unit="frames", disable=verbose is not False
) as pbar:
last_speech_timestamp = 0.0
for seek_clip_start, seek_clip_end in seek_clips:
while seek < seek_clip_end:
time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
window_end_time = float((seek + N_FRAMES) * HOP_LENGTH / SAMPLE_RATE)
segment_size = min(
N_FRAMES, content_frames - seek, seek_clip_end - seek
)
mel_segment = mel[seek : seek + segment_size]
segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
mel_segment = pad_or_trim(mel_segment, N_FRAMES, axis=-2).astype(dtype)
decode_options["prompt"] = all_tokens[prompt_reset_since:]
result: DecodingResult = decode_with_fallback(mel_segment)
tokens = np.array(result.tokens)
if no_speech_threshold is not None:
# no voice activity check
should_skip = result.no_speech_prob > no_speech_threshold
if (
logprob_threshold is not None
and result.avg_logprob > logprob_threshold
):
# don't skip if the logprob is high enough, despite the no_speech_prob
should_skip = False
if should_skip:
seek += (
segment_size # fast-forward to the next segment boundary
)
continue
previous_seek = seek
current_segments = []
# anomalous words are very long/short/improbable
def word_anomaly_score(word: dict) -> float:
probability = word.get("probability", 0.0)
duration = word["end"] - word["start"]
score = 0.0
if probability < 0.15:
score += 1.0
if duration < 0.133:
score += (0.133 - duration) * 15
if duration > 2.0:
score += duration - 2.0
return score
def is_segment_anomaly(segment: Optional[dict]) -> bool:
if segment is None or not segment["words"]:
return False
words = [
w for w in segment["words"] if w["word"] not in punctuation
]
words = words[:8]
score = sum(word_anomaly_score(w) for w in words)
return score >= 3 or score + 0.01 >= len(words)
def next_words_segment(segments: List[dict]) -> Optional[dict]:
return next((s for s in segments if s["words"]), None)
timestamp_tokens = tokens >= tokenizer.timestamp_begin
single_timestamp_ending = timestamp_tokens[-2:].tolist() == [
False,
True,
]
consecutive = np.where(
np.logical_and(timestamp_tokens[:-1], timestamp_tokens[1:])
)[0]
consecutive += 1
if len(consecutive) > 0:
# if the output contains two consecutive timestamp tokens
slices = consecutive.tolist()
if single_timestamp_ending:
slices.append(len(tokens))
last_slice = 0
for current_slice in slices:
sliced_tokens = tokens[last_slice:current_slice]
start_timestamp_pos = (
sliced_tokens[0].item() - tokenizer.timestamp_begin
)
end_timestamp_pos = (
sliced_tokens[-1].item() - tokenizer.timestamp_begin
)
current_segments.append(
new_segment(
start=time_offset
+ start_timestamp_pos * time_precision,
end=time_offset + end_timestamp_pos * time_precision,
tokens=sliced_tokens,
result=result,
)
)
last_slice = current_slice
if single_timestamp_ending:
# single timestamp at the end means no speech after the last timestamp.
seek += segment_size
else:
# otherwise, ignore the unfinished segment and seek to the last timestamp
last_timestamp_pos = (
tokens[last_slice - 1].item() - tokenizer.timestamp_begin
)
seek += last_timestamp_pos * input_stride
else:
duration = segment_duration
timestamps = tokens[timestamp_tokens.nonzero()[0]]
if (
len(timestamps) > 0
and timestamps[-1].item() != tokenizer.timestamp_begin
):
# no consecutive timestamps but it has a timestamp; use the last one.
last_timestamp_pos = (
timestamps[-1].item() - tokenizer.timestamp_begin
)
duration = last_timestamp_pos * time_precision
current_segments.append(
new_segment(
start=time_offset,
end=time_offset + duration,
tokens=tokens,
result=result,
)
)
seek += segment_size
if word_timestamps:
add_word_timestamps(
segments=current_segments,
model=model,
tokenizer=tokenizer,
mel=mel_segment,
num_frames=segment_size,
prepend_punctuations=prepend_punctuations,
append_punctuations=append_punctuations,
last_speech_timestamp=last_speech_timestamp,
)
if not single_timestamp_ending:
last_word_end = _get_end(current_segments)
if last_word_end is not None and last_word_end > time_offset:
seek = round(last_word_end * FRAMES_PER_SECOND)
# skip silence before possible hallucinations
if hallucination_silence_threshold is not None:
threshold = hallucination_silence_threshold
if not single_timestamp_ending:
last_word_end = _get_end(current_segments)
if (
last_word_end is not None
and last_word_end > time_offset
):
remaining_duration = window_end_time - last_word_end
if remaining_duration > threshold:
seek = round(last_word_end * FRAMES_PER_SECOND)
else:
seek = previous_seek + segment_size
# if first segment might be a hallucination, skip leading silence
first_segment = next_words_segment(current_segments)
if first_segment is not None and is_segment_anomaly(
first_segment
):
gap = first_segment["start"] - time_offset
if gap > threshold:
seek = previous_seek + round(gap * FRAMES_PER_SECOND)
continue
# skip silence before any possible hallucination that is surrounded
# by silence or more hallucinations
hal_last_end = last_speech_timestamp
for si in range(len(current_segments)):
segment = current_segments[si]
if not segment["words"]:
continue
if is_segment_anomaly(segment):
next_segment = next_words_segment(
current_segments[si + 1 :]
)
if next_segment is not None:
hal_next_start = next_segment["words"][0]["start"]
else:
hal_next_start = time_offset + segment_duration
silence_before = (
segment["start"] - hal_last_end > threshold
or segment["start"] < threshold
or segment["start"] - time_offset < 2.0
)
silence_after = (
hal_next_start - segment["end"] > threshold
or is_segment_anomaly(next_segment)
or window_end_time - segment["end"] < 2.0
)
if silence_before and silence_after:
seek = round(
max(time_offset + 1, segment["start"])
* FRAMES_PER_SECOND
)
if content_duration - segment["end"] < threshold:
seek = content_frames
current_segments[si:] = []
break
hal_last_end = segment["end"]
last_word_end = _get_end(current_segments)
if last_word_end is not None:
last_speech_timestamp = last_word_end
if verbose:
for segment in current_segments:
start, end, text = (
segment["start"],
segment["end"],
segment["text"],
)
line = f"[{_format_timestamp(start)} --> {_format_timestamp(end)}] {text}"
print(make_safe(line))
# if a segment is instantaneous or does not contain text, clear it
for i, segment in enumerate(current_segments):
if (
segment["start"] == segment["end"]
or segment["text"].strip() == ""
):
segment["text"] = ""
segment["tokens"] = []
segment["words"] = []
all_segments.extend(
[
{"id": i, **segment}
for i, segment in enumerate(
current_segments, start=len(all_segments)
)
]
)
all_tokens.extend(
[
token
for segment in current_segments
for token in segment["tokens"]
]
)
if not condition_on_previous_text or result.temperature > 0.5:
# do not feed the prompt tokens if a high temperature was used
prompt_reset_since = len(all_tokens)
# update progress bar
pbar.update(min(content_frames, seek) - previous_seek)
# Print detailed inference metrics
total_time = model_load_time + mel_time + total_inference_time
rtf = (total_time / content_duration) if content_duration > 0 else 0
if verbose is not False:
print("\n" + "=" * 80)
print("BENCHMARK METRICS")
print("=" * 80)
print(f"Model load time: {model_load_time * 1000:.2f} ms")
print(f"Mel spectrogram time: {mel_time * 1000:.2f} ms")
print(f"Inference time: {total_inference_time * 1000:.2f} ms")
print(f"Total time: {total_time * 1000:.2f} ms")
print(f"Audio duration: {content_duration:.2f} s")
print(f"RTF (Real-Time Factor): {rtf:.3f}")
print(f"\nTotal output tokens: {total_tokens_generated}")
print(
f"Total inference steps (decoder forward passes): {total_inference_steps}"
)
if total_inference_time > 0:
print(
f"Average output tokens/sec: {total_tokens_generated / total_inference_time:.2f}"
)
print(
f"Average inference steps/sec: {total_inference_steps / total_inference_time:.2f}"
)
print(f"Number of segments: {len(segment_timings)}")
if verbose and len(segment_timings) > 0:
print("\nPer-segment details:")
print(
f"{'Seg#':<6} {'Out':<6} {'Steps':<7} {'Time(s)':<10} {'Out/s':<10} {'Steps/s':<10}"
)
print("-" * 80)
for i, timing in enumerate(segment_timings):
print(
f"{i:<6} {timing['output_tokens']:<6} {timing['inference_steps']:<7} "
f"{timing['time']:<10.3f} {timing['output_tokens_per_sec']:<10.2f} "
f"{timing['inference_steps_per_sec']:<10.2f}"
)
print("\nNOTE:")
print("- RTF < 1.0 means faster than real-time")
print("- 'Output tokens': Final text tokens (excluding special tokens)")
print(
"- 'Inference steps': Total decoder forward passes (comparable to whisper.cpp 'runs')"
)
print("=" * 80 + "\n")
return dict(
text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),
segments=all_segments,
language=language,
)