Skip to content

Commit c06d9e1

Browse files
committed
improve time codes and add optional chunking on generated subtitle segments
1 parent 59bdd89 commit c06d9e1

File tree

5 files changed

+176
-34
lines changed

5 files changed

+176
-34
lines changed

Pipfile

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,8 @@ tensorflow = ">=1.15.5,<2.12"
8484
termcolor = "==1.1.0"
8585
toml = "==0.10.0"
8686
toolz = "==0.9.0"
87-
torch = "<2.2.0"
87+
torch = "<2.3.0"
88+
torchaudio = "<2.3.0"
8889
transformers = "<4.27.0"
8990
urllib3 = "~=1.26.5"
9091
wrapt = "==1.14.0"

requirements-llm.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
sentencepiece~=0.1.95
22
torch<2.3.0
3+
torchaudio<2.3.0
34
transformers<4.37.0
4-
openai-whisper==20240930
5+
openai-whisper==20240930

subaligner/__main__.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,13 @@ def main():
187187
default=None,
188188
help="Optional text to provide the transcribing context or specific phrases"
189189
)
190+
parser.add_argument(
191+
"-mcl",
192+
"--max_char_length",
193+
type=int,
194+
default=None,
195+
help="Maximum number of characters for each generated subtitle segment"
196+
)
190197
from subaligner.llm import TranslationRecipe
191198
from subaligner.llm import HelsinkiNLPFlavour
192199
parser.add_argument(
@@ -356,9 +363,15 @@ def main():
356363
from subaligner.transcriber import Transcriber
357364
transcriber = Transcriber(recipe=FLAGS.transcription_recipe, flavour=FLAGS.transcription_flavour)
358365
if "_transcribe_temp" in local_subtitle_path:
359-
subtitle, frame_rate = transcriber.transcribe(video_file_path=local_video_path, language_code=stretch_in_lang, initial_prompt=FLAGS.initial_prompt)
366+
subtitle, frame_rate = transcriber.transcribe(video_file_path=local_video_path,
367+
language_code=stretch_in_lang,
368+
initial_prompt=FLAGS.initial_prompt,
369+
max_char_length=FLAGS.max_char_length)
360370
else:
361-
subtitle, frame_rate = transcriber.transcribe_with_subtitle_as_prompts(video_file_path=local_video_path, subtitle_file_path=local_subtitle_path, language_code=stretch_in_lang)
371+
subtitle, frame_rate = transcriber.transcribe_with_subtitle_as_prompts(video_file_path=local_video_path,
372+
subtitle_file_path=local_subtitle_path,
373+
language_code=stretch_in_lang,
374+
max_char_length=FLAGS.max_char_length)
362375
aligned_subs = subtitle.subs
363376
else:
364377
print("ERROR: Unknown mode {}".format(FLAGS.mode))

subaligner/transcriber.py

Lines changed: 139 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,18 @@
11
import os
22
import whisper
33
import torch
4-
from typing import Tuple, Optional
4+
import concurrent.futures
5+
import math
6+
import multiprocessing as mp
7+
import torchaudio
8+
import numpy as np
9+
from functools import partial
10+
from threading import Lock
11+
from typing import Tuple, Optional, Dict, List
512
from pysrt import SubRipTime
13+
from whisper import Whisper
614
from whisper.tokenizer import LANGUAGES
15+
from tqdm import tqdm
716
from .subtitle import Subtitle
817
from .media_helper import MediaHelper
918
from .llm import TranscriptionRecipe, WhisperFlavour
@@ -38,14 +47,20 @@ def __init__(self, recipe: str = TranscriptionRecipe.WHISPER.value, flavour: str
3847
self.__flavour = flavour
3948
self.__media_helper = MediaHelper()
4049
self.__LOGGER = Logger().get_logger(__name__)
50+
self.__lock = Lock()
4151

42-
def transcribe(self, video_file_path: str, language_code: str, initial_prompt: Optional[str] = None) -> Tuple[Subtitle, Optional[float]]:
52+
def transcribe(self,
53+
video_file_path: str,
54+
language_code: str,
55+
initial_prompt: Optional[str] = None,
56+
max_char_length: Optional[int] = None) -> Tuple[Subtitle, Optional[float]]:
4357
"""Transcribe an audiovisual file and generate subtitles.
4458
4559
Arguments:
4660
video_file_path {string} -- The input video file path.
4761
language_code {string} -- An alpha 3 language code derived from ISO 639-3.
48-
initial_prompt {string} -- Optional text to provide the transcribing context or specific phrases.
62+
initial_prompt {string} -- Optional Text to provide the transcribing context or specific phrases.
63+
max_char_length {int} -- Optional Maximum number of characters for each generated subtitle segment.
4964
5065
Returns:
5166
tuple: Generated subtitle after transcription and the detected frame rate
@@ -64,14 +79,24 @@ def transcribe(self, video_file_path: str, language_code: str, initial_prompt: O
6479
self.__LOGGER.info("Start transcribing the audio...")
6580
verbose = False if Logger.VERBOSE and not Logger.QUIET else None
6681
self.__LOGGER.debug("Prompting with: '%s'" % initial_prompt)
67-
result = self.__model.transcribe(audio, task="transcribe", language=LANGUAGES[lang], verbose=verbose, initial_prompt=initial_prompt)
82+
result = self.__model.transcribe(audio,
83+
task="transcribe",
84+
language=LANGUAGES[lang],
85+
verbose=verbose,
86+
word_timestamps=True,
87+
initial_prompt=initial_prompt)
6888
self.__LOGGER.info("Finished transcribing the audio")
6989
srt_str = ""
70-
for i, segment in enumerate(result["segments"], start=1):
71-
srt_str += f"{i}\n" \
72-
f"{Utils.format_timestamp(segment['start'])} --> {Utils.format_timestamp(segment['end'])}\n" \
73-
f"{segment['text'].strip().replace('-->', '->')}\n" \
74-
"\n"
90+
srt_idx = 1
91+
for segment in result["segments"]:
92+
if max_char_length is not None and len(segment["text"]) > max_char_length:
93+
srt_str, srt_idx = self._chunk_segment(segment, srt_str, srt_idx, max_char_length)
94+
else:
95+
srt_str += f"{srt_idx}\n" \
96+
f"{Utils.format_timestamp(segment['words'][0]['start'])} --> {Utils.format_timestamp(segment['words'][-1]['end'])}\n" \
97+
f"{segment['text'].strip().replace('-->', '->')}\n" \
98+
"\n"
99+
srt_idx += 1
75100
subtitle = Subtitle.load_subrip_str(srt_str)
76101
subtitle, frame_rate = self.__on_frame_timecodes(subtitle, video_file_path)
77102
self.__LOGGER.debug("Generated the raw subtitle")
@@ -82,13 +107,19 @@ def transcribe(self, video_file_path: str, language_code: str, initial_prompt: O
82107
else:
83108
raise NotImplementedError(f"{self.__recipe} ({self.__flavour}) is not supported")
84109

85-
def transcribe_with_subtitle_as_prompts(self, video_file_path: str, subtitle_file_path: str, language_code: str) -> Tuple[Subtitle, Optional[float]]:
86-
"""Transcribe an audiovisual file and generate subtitles using the original subtitle as prompts.
110+
def transcribe_with_subtitle_as_prompts(self,
111+
video_file_path: str,
112+
subtitle_file_path: str,
113+
language_code: str,
114+
max_char_length: Optional[int] = None) -> Tuple[Subtitle, Optional[float]]:
115+
"""Transcribe an audiovisual file and generate subtitles using the original subtitle (with accurate time codes) as prompts.
116+
87117
88118
Arguments:
89119
video_file_path {string} -- The input video file path.
90120
subtitle_file_path {string} -- The input subtitle file path to provide prompts.
91121
language_code {string} -- An alpha 3 language code derived from ISO 639-3.
122+
max_char_length {int} -- Optional Maximum number of characters for each generated subtitle segment.
92123
93124
Returns:
94125
tuple: Generated subtitle after transcription and the detected frame rate
@@ -104,27 +135,54 @@ def transcribe_with_subtitle_as_prompts(self, video_file_path: str, subtitle_fil
104135
f'"{language_code}" is not supported by {self.__recipe} ({self.__flavour})')
105136
audio_file_path = self.__media_helper.extract_audio(video_file_path, True, 16000)
106137
subtitle = Subtitle.load(subtitle_file_path)
107-
segment_paths = []
138+
segment_paths: List[str] = []
108139
try:
109140
srt_str = ""
110141
srt_idx = 1
111142
self.__LOGGER.info("Start transcribing the audio...")
112-
verbose = False if Logger.VERBOSE and not Logger.QUIET else None
113-
for sub in subtitle.subs:
143+
segment_paths = []
144+
args = []
145+
longest_segment_char_length = 0
146+
for sub in tqdm(subtitle.subs, desc="Extracting audio segments"):
114147
segment_path, _ = self.__media_helper.extract_audio_from_start_to_end(audio_file_path, str(sub.start), str(sub.end))
115148
segment_paths.append(segment_path)
116-
audio = whisper.load_audio(segment_path)
117-
result = self.__model.transcribe(audio, task="transcribe", language=LANGUAGES[lang], verbose=verbose, initial_prompt=sub.text)
149+
args.append((segment_path, sub.text, self.__lock, self.__LOGGER))
150+
if len(sub.text) > longest_segment_char_length:
151+
longest_segment_char_length = len(sub.text)
152+
max_subtitle_char_length = max_char_length or longest_segment_char_length
153+
154+
max_workers = math.ceil(float(os.getenv("MAX_WORKERS", mp.cpu_count() / 2)))
155+
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
156+
results = list(executor.map(partial(self._whisper_transcribe, model=self.__model, lang=lang), args))
157+
for sub, result in zip(subtitle.subs, results):
118158
original_start_in_secs = sub.start.hours * 3600 + sub.start.minutes * 60 + sub.start.seconds + sub.start.milliseconds / 1000.0
119159
original_end_in_secs = sub.end.hours * 3600 + sub.end.minutes * 60 + sub.end.seconds + sub.end.milliseconds / 1000.0
120-
for segment in result["segments"]:
121-
if segment["end"] <= segment["start"]:
122-
continue
160+
if len(result["segments"]) == 0:
123161
srt_str += f"{srt_idx}\n" \
124-
f"{Utils.format_timestamp(original_start_in_secs + segment['start'])} --> {Utils.format_timestamp(min(original_start_in_secs + segment['end'], original_end_in_secs))}\n" \
125-
f"{segment['text'].strip().replace('-->', '->')}\n" \
162+
f"{Utils.format_timestamp(original_start_in_secs)} --> {Utils.format_timestamp(original_end_in_secs)}\n" \
163+
f"{sub.text.strip().replace('-->', '->')}\n" \
126164
"\n"
127165
srt_idx += 1
166+
else:
167+
for segment in result["segments"]:
168+
if segment["end"] <= segment["start"]:
169+
continue
170+
segment_end = min(original_start_in_secs + segment["end"], original_end_in_secs)
171+
if len(segment["text"]) > max_subtitle_char_length:
172+
srt_str, srt_idx = self._chunk_segment(segment,
173+
srt_str,
174+
srt_idx,
175+
max_subtitle_char_length,
176+
original_start_in_secs,
177+
original_end_in_secs)
178+
else:
179+
srt_str += f"{srt_idx}\n" \
180+
f"{Utils.format_timestamp(original_start_in_secs + segment['start'])} --> {Utils.format_timestamp(segment_end)}\n" \
181+
f"{segment['text'].strip().replace('-->', '->')}\n" \
182+
"\n"
183+
srt_idx += 1
184+
if segment_end == original_end_in_secs:
185+
break
128186
self.__LOGGER.info("Finished transcribing the audio")
129187
subtitle = Subtitle.load_subrip_str(srt_str)
130188
subtitle, frame_rate = self.__on_frame_timecodes(subtitle, video_file_path)
@@ -139,6 +197,66 @@ def transcribe_with_subtitle_as_prompts(self, video_file_path: str, subtitle_fil
139197
else:
140198
raise NotImplementedError(f"{self.__recipe} ({self.__flavour}) is not supported")
141199

200+
@staticmethod
def _whisper_transcribe(args: Tuple, model: Whisper, lang: str) -> Dict:
    """Transcribe a single extracted audio segment with Whisper.

    Designed to be mapped over a thread pool: the per-segment inputs are
    packed into ``args`` while ``model`` and ``lang`` are bound up front.

    Arguments:
        args {tuple} -- ``(segment_path, sub_text, lock, logger)``: the path of the
            audio segment, the original subtitle text used as the initial prompt,
            the lock guarding the shared model and the logger to report to.
        model {Whisper} -- The loaded Whisper model shared by all workers.
        lang {string} -- The key into whisper's ``LANGUAGES`` mapping.

    Returns:
        dict: The Whisper transcription result, or ``{"segments": []}`` when the
            segment could not be loaded or transcribed.
    """
    segment_path, sub_text, lock, logger = args
    verbose = False if Logger.VERBOSE and not Logger.QUIET else None
    try:
        waveform, _ = torchaudio.load(segment_path)
        # The loaded tensor is laid out as (channels, samples). Always collapse the
        # channel axis, so that mono input is handed to the model as a 1-D waveform
        # too instead of staying 2-D with a leading axis of size one.
        if waveform.dim() > 1:
            waveform = waveform.mean(dim=0)
        waveform = waveform.numpy().astype(np.float32)
        # Only one thread at a time may run inference on the shared model.
        with lock:
            result = model.transcribe(waveform,
                                      task="transcribe",
                                      language=LANGUAGES[lang],
                                      verbose=verbose,
                                      initial_prompt=sub_text,
                                      word_timestamps=True)
        logger.debug("Segment transcribed : %s", result)
        return result
    except Exception as e:
        # Deliberately best-effort: a failing segment must not abort the whole run,
        # so report it and hand back an empty result for the caller to deal with.
        logger.warning(f"Error while transcribing segment: {e}")
        return {"segments": []}
221+
222+
@staticmethod
def _chunk_segment(segment: Dict,
                   srt_str: str,
                   srt_idx: int,
                   max_subtitle_char_length: int,
                   start_offset: float = 0.0,
                   end_ceiling: float = float("inf")) -> Tuple[str, int]:
    """Split a transcribed segment into several SubRip cues on word boundaries.

    Words are accumulated until adding the next one would exceed
    ``max_subtitle_char_length``; the accumulated chunk is then emitted as a cue
    timed from its first word's start to its last word's end. A single word longer
    than the limit is still emitted on its own rather than being cut.

    Arguments:
        segment {dict} -- A segment carrying a ``words`` list whose items have
            ``word``, ``start`` and ``end`` entries.
        srt_str {string} -- The SubRip text generated so far.
        srt_idx {int} -- The index to assign to the next cue.
        max_subtitle_char_length {int} -- Maximum number of characters per chunk.
        start_offset {float} -- Seconds added to every word time code (default 0.0).
        end_ceiling {float} -- Upper bound in seconds for cue end times (default no bound).

    Returns:
        tuple: The extended SubRip text and the next unused cue index.
    """

    def render_cue(index: int, start_in_secs: float, end_in_secs: float, text: str) -> str:
        # One SubRip cue; "-->" inside the text would be mistaken for a time code separator.
        return f"{index}\n" \
               f"{Utils.format_timestamp(start_offset + start_in_secs)} --> {Utils.format_timestamp(min(start_offset + end_in_secs, end_ceiling))}\n" \
               f"{text.strip().replace('-->', '->')}\n" \
               "\n"

    cues: List[str] = []
    chunked_text = ""
    # None (rather than 0.0) marks "no start recorded yet", since 0.0 is a valid
    # start time for the first word and must not be overwritten by later words.
    chunk_start_in_secs: Optional[float] = None
    chunk_end_in_secs = 0.0

    for word in segment["words"]:
        token = word["word"]
        exceeds_limit = len(chunked_text) + len(token) > max_subtitle_char_length
        if exceeds_limit and chunked_text.strip() != "":
            # Flush the current chunk and start a new one with this word.
            cues.append(render_cue(srt_idx, chunk_start_in_secs, chunk_end_in_secs, chunked_text))
            srt_idx += 1
            chunked_text = token
            chunk_start_in_secs = word["start"]
        else:
            if chunk_start_in_secs is None:
                chunk_start_in_secs = word["start"]
            chunked_text += token
        chunk_end_in_secs = word["end"]

    # Emit whatever is left, unless it is blank and would yield an empty cue.
    if chunked_text.strip() != "":
        cues.append(render_cue(srt_idx, chunk_start_in_secs, chunk_end_in_secs, chunked_text))
        srt_idx += 1

    return srt_str + "".join(cues), srt_idx
259+
142260
def __on_frame_timecodes(self, subtitle: Subtitle, video_file_path: str) -> Tuple[Subtitle, Optional[float]]:
143261
frame_rate = None
144262
try:

tox.ini

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,41 @@
11
[tox]
22
envlist =
3-
py36,
4-
py37,
5-
py38
3+
py38,
4+
py39,
5+
py310,
6+
py311
67
skipsdist=True
78
skip_missing_interpreters = True
89

910
[darglint]
1011
ignore=DAR101
1112

12-
[testenv:py36]
13-
basepython = python3.6
13+
[testenv:py38]
14+
basepython = python3.8
1415
whitelist_externals = /bin/bash
1516
commands =
1617
bash -c \'cat requirements.txt | xargs -L 1 pip install\'
1718
bash -c \'cat requirements-dev.txt | xargs -L 1 pip install'
1819
python -m unittest discover
1920

20-
[testenv:py37]
21-
basepython = python3.7
21+
[testenv:py39]
22+
basepython = python3.9
2223
whitelist_externals = /bin/bash
2324
commands =
2425
bash -c \'cat requirements.txt | xargs -L 1 pip install\'
2526
bash -c \'cat requirements-dev.txt | xargs -L 1 pip install'
2627
python -m unittest discover
2728

28-
[testenv:py38]
29-
basepython = python3.8
29+
[testenv:py310]
30+
basepython = python3.10
31+
whitelist_externals = /bin/bash
32+
commands =
33+
bash -c \'cat requirements.txt | xargs -L 1 pip install\'
34+
bash -c \'cat requirements-dev.txt | xargs -L 1 pip install'
35+
python -m unittest discover
36+
37+
[testenv:py311]
38+
basepython = python3.11
3039
whitelist_externals = /bin/bash
3140
commands =
3241
bash -c \'cat requirements.txt | xargs -L 1 pip install\'

0 commit comments

Comments
 (0)