Skip to content

Commit bc44db8

Browse files
committed
add possibility to use many output formats (srt, csv, etc.)
1 parent 82f9f95 commit bc44db8

File tree

2 files changed

+186
-32
lines changed

2 files changed

+186
-32
lines changed

pyproject.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
# https://packaging.python.org/en/latest/guides/writing-pyproject-toml/#writing-pyproject-toml
22
[build-system]
3-
requires = ["hatchling"]
4-
build-backend = "hatchling.build"
3+
requires = ["setuptools"]
4+
build-backend = "setuptools.build_meta"
55

66
[project]
77
name = "pytranscript"
8-
version = "0.1.1"
8+
version = "0.2.1"
99
description = "CLI to transcript and translate audio and video files"
1010
readme = "README.md"
1111
requires-python = ">=3.12"

src/pytranscript.py

Lines changed: 183 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,27 @@
22

33
import json
44
import logging
5+
import typing
56
import warnings
67
import wave
78
from dataclasses import dataclass, field
89
from pathlib import Path
9-
from typing import NamedTuple
10+
from typing import TYPE_CHECKING, Literal, NamedTuple
1011

1112
import deep_translator
1213
import ffmpeg
1314
import tap # typed_argument_parser
1415
import vosk
1516
from tqdm import tqdm
1617

18+
if TYPE_CHECKING:
19+
from os import PathLike
20+
type StrPath = str | PathLike[str]
21+
22+
23+
TranscriptFormat = Literal["csv", "json", "srt", "txt", "vtt"]
24+
TRANSCRIPT_FORMATS: tuple[TranscriptFormat] = typing.get_args(TranscriptFormat)
25+
1726

1827
class LineError(NamedTuple):
1928
time: float
@@ -51,6 +60,23 @@ def seconds_to_time(seconds: float) -> str:
5160
return f"{days}d {hours:02d}:{seconds_to_time(minutes)}"
5261

5362

63+
def seconds_to_srt_time(seconds: float) -> str:
64+
"""Convert seconds to SRT time format.
65+
66+
Args:
67+
seconds (float): the number of seconds
68+
69+
Returns:
70+
str: the time in the format "hh:mm:ss,ms"
71+
"""
72+
seconds = float(seconds)
73+
hours, remainder = divmod(seconds, 3600)
74+
minutes, seconds = divmod(remainder, 60)
75+
int_seconds, dec_seconds = str(seconds).split(".")
76+
dec_seconds = dec_seconds[:3]
77+
return f"{int(hours):02d}:{int(minutes):02d}:{int(int_seconds):02d},{dec_seconds}"
78+
79+
5480
@dataclass
5581
class Transcript:
5682
"""A transcript of a video or audio file.
@@ -60,11 +86,20 @@ class Transcript:
6086
time (list[float]): a list of the time of each line in the transcript.
6187
text (list[str]): a list of the text of each line in the transcript.
6288
language (str): the language of the transcript. Default: "auto"
89+
time_end (float): the time of the last line in the transcript. If not specified,
90+
it will be the time of the last line + 5 seconds.
6391
"""
6492

6593
time: list[float] = field(default_factory=list)
6694
text: list[str] = field(default_factory=list)
6795
language: str = "auto"
96+
time_end: float | None = None
97+
98+
@property
99+
def _time_end(self) -> float:
100+
if self.time_end is not None:
101+
return self.time_end
102+
return self.time[-1] + 5
68103

69104
def append(self, time: float, text: str) -> None:
70105
self.time.append(time)
@@ -76,6 +111,12 @@ def __str__(self):
76111
for time, line in zip(self.time, self.text, strict=True)
77112
)
78113

114+
def __len__(self):
115+
return len(self.time)
116+
117+
def __getitem__(self, key):
118+
return self.time[key], self.text[key]
119+
79120
def translate(self, target: str) -> tuple[Transcript, list[LineError]]:
80121
"""Return a translated version of the transcript.
81122
@@ -89,11 +130,13 @@ def translate(self, target: str) -> tuple[Transcript, list[LineError]]:
89130
- Transcript: the translated transcript
90131
- list[LineError]: a list of errors that occurred during the translation
91132
"""
92-
translated = Transcript()
133+
translated = Transcript(time_end=self.time_end)
93134
errors: list[LineError] = []
94135

95136
_iter = zip(self.time, self.text, strict=True)
96-
pbar = tqdm(_iter, total=len(self.time), unit_scale=True, unit="line")
137+
pbar = tqdm(
138+
_iter, total=len(self.time), unit_scale=True, unit="line", desc="Translate"
139+
)
97140
for time, line in pbar:
98141
try:
99142
translator = deep_translator.GoogleTranslator(
@@ -105,9 +148,75 @@ def translate(self, target: str) -> tuple[Transcript, list[LineError]]:
105148
errors.append(LineError(time, line, e))
106149
return translated, errors
107150

151+
def srt_generator(self):
152+
"""Generate the transcript as a string in SRT format, line by line."""
153+
154+
def one_line(start, end, line):
155+
start, end = map(seconds_to_srt_time, (start, end))
156+
return f"{start} --> {end}\n{line}\n\n"
157+
158+
nb_lines = len(self)
159+
for i, (time, line) in enumerate(
160+
zip(self.time, self.text, strict=True), start=1
161+
):
162+
if i == nb_lines:
163+
end = self._time_end
164+
else:
165+
after_time = self.time[i]
166+
end = min(after_time, time + 5)
167+
yield one_line(time, end, line)
168+
169+
def vtt_generator(self):
170+
"""Generate the transcript as a string in VTT format, line by line."""
171+
yield "WEBVTT\n\n"
172+
for srt_line in self.srt_generator():
173+
if " --> " in srt_line:
174+
yield srt_line.replace(",", ".")
175+
else:
176+
yield srt_line
177+
178+
def csv_generator(self):
179+
"""Generate the transcript as a string in CSV format, line by line."""
180+
yield "time,text\n"
181+
for time, line in zip(self.time, self.text, strict=True):
182+
yield f"{time},{line}\n"
183+
184+
def to_srt(self) -> str:
185+
"""Return the transcript as a string in SRT format."""
186+
return "".join(self.srt_generator())
187+
188+
def to_vtt(self) -> str:
189+
"""Return the transcript as a string in VTT format."""
190+
return "".join(self.vtt_generator())
191+
192+
def to_json(self) -> str:
193+
"""Return the transcript as a string in JSON format."""
194+
return json.dumps({"text": self.text, "time": self.time})
195+
196+
def to_txt(self) -> str:
197+
"""Return the transcript as a string in TXT format."""
198+
return str(self)
199+
200+
def to_csv(self) -> str:
201+
"""Return the transcript as a string in CSV format."""
202+
return "".join(self.csv_generator())
203+
204+
def write(self, output: StrPath, format: TranscriptFormat) -> None: # noqa: A002
205+
"""Write the transcript to a file.
206+
207+
Args:
208+
output (Path): the path to the output file.
209+
format (TranscriptFormat): the format of the transcript.
210+
"""
211+
method = getattr(self, f"to_{format}")
212+
Path(output).write_text(method())
213+
108214

109215
def to_valid_wav(
110-
source: Path, output: Path | None = None, start: float = 0, end: float | None = None
216+
source: StrPath,
217+
output: StrPath | None = None,
218+
start: float = 0,
219+
end: float | None = None,
111220
) -> Path:
112221
"""Convert a video or audio file to a wav file.
113222
@@ -126,13 +235,16 @@ def to_valid_wav(
126235
Returns:
127236
Path of the converted file.
128237
"""
129-
start, end = int(start * 1000), int(end * 1000) if end is not None else None
238+
source = Path(source)
239+
start = int(start * 1000)
240+
end = int(end * 1000) if end is not None else None
130241
wav_file = source.with_suffix(".wav")
131242
if wav_file == source:
132243
if _is_valid_wav_file(source):
133244
return source
134-
wav_file = source.rename(f"{source.stem}_converted.wav")
135-
output_path = wav_file if output is None else output
245+
wav_file = Path(f'{source.with_suffix("")}_converted.wav')
246+
247+
output_path = wav_file if output is None else Path(output)
136248

137249
args = {"ss": start, "loglevel": "warning"}
138250
if end is not None:
@@ -172,7 +284,7 @@ def parse_data_buffer(
172284

173285

174286
def transcribe(
175-
input_file: Path, model_path: Path, max_size: int | None = None
287+
input_file: StrPath, model_path: StrPath, max_size: int | None = None
176288
) -> Transcript:
177289
"""Transcribe a mono PCM 16-bit WAV file using a vosk model
178290
(https://alphacephei.com/vosk/models).
@@ -190,6 +302,9 @@ def transcribe(
190302
Returns:
191303
Transcript: the transcript of the file
192304
"""
305+
input_file = Path(input_file)
306+
model_path = Path(model_path)
307+
193308
if not input_file.is_file():
194309
msg = f"{input_file} not found"
195310
raise FileNotFoundError(msg)
@@ -209,7 +324,14 @@ def transcribe(
209324

210325
def _is_valid_wav_file(input_file: Path) -> bool:
211326
"""Validate if the input file is a valid WAV file."""
212-
wf = wave.Wave_read(str(input_file))
327+
try:
328+
wf = wave.Wave_read(str(input_file))
329+
except wave.Error as e:
330+
# if it is not a valid wav file for wave_read itself
331+
if "unknown format" in str(e):
332+
return False
333+
raise e from None
334+
213335
is_mono = wf.getnchannels() == 1
214336
is_pcm = wf.getcomptype() == "NONE"
215337
is_16bit = wf.getsampwidth() == 2 # noqa: PLR2004
@@ -218,7 +340,6 @@ def _is_valid_wav_file(input_file: Path) -> bool:
218340

219341
def _initialize_recognizer(model: vosk.Model, input_file: Path) -> vosk.KaldiRecognizer:
220342
"""Initialize the Vosk recognizer."""
221-
# for a weird reason, Wave_read does not work with Path objects
222343
wave_form = wave.Wave_read(str(input_file))
223344
rec = vosk.KaldiRecognizer(model, wave_form.getframerate())
224345

@@ -231,16 +352,21 @@ def _initialize_recognizer(model: vosk.Model, input_file: Path) -> vosk.KaldiRec
231352

232353

233354
def transcribe_with_vosk(
234-
input_file: Path, rec: vosk.KaldiRecognizer, max_size: int | None
355+
input_file: StrPath, rec: vosk.KaldiRecognizer, max_size: int | None
235356
) -> Transcript:
236357
"""Transcribe the file using the Vosk recognizer."""
358+
input_file = Path(input_file)
359+
237360
wave_form = wave.Wave_read(str(input_file))
238361
file_size = input_file.stat().st_size
239362
if max_size is not None and max_size < file_size:
240363
file_size = max_size
241-
pbar = tqdm(total=file_size, unit="B", unit_scale=True)
364+
pbar = tqdm(
365+
total=file_size, unit="B", unit_scale=True, desc=f"Transcribe {input_file}"
366+
)
242367

243-
transcript = Transcript()
368+
time_end = wave_form.getnframes() / wave_form.getframerate()
369+
transcript = Transcript(time_end=time_end)
244370
total_data = 0
245371
len_data = 1 # initialize with 1 to enter the loop
246372
while len_data > 0 and total_data < file_size:
@@ -256,14 +382,26 @@ def transcribe_with_vosk(
256382
return transcript
257383

258384

385+
AllTranscriptFormats = TranscriptFormat | Literal["all"]
386+
387+
259388
class ArgumentParser(tap.Tap):
260389
"""Transcribe a file and optionally translate the transcript."""
261390

262391
input: Path
263392
"the path to the audio file"
264393

265394
output: Path | None = None
266-
"the path to the output file. Default: input file with .txt extension"
395+
"""
396+
the path to the output file. Default: same as the input file with only the extension
397+
changed
398+
"""
399+
400+
format: str = "all"
401+
"""
402+
the format of the transcript. Must be one of 'csv', 'json', 'srt', 'txt', 'vtt'
403+
or 'all'. Default: 'all'
404+
"""
267405

268406
model: Path = Path("model")
269407
"the path to the vosk model"
@@ -291,6 +429,10 @@ class ArgumentParser(tap.Tap):
291429
3: debug. Default: 2."""
292430

293431
def process_args(self):
432+
if self.format not in typing.get_args(AllTranscriptFormats):
433+
msg = f"bad transcript format: {self.format}"
434+
raise ValueError(msg)
435+
294436
vosk.SetLogLevel(-1) # disable vosk logs
295437
match self.verbosity:
296438
case 0:
@@ -309,40 +451,52 @@ def process_args(self):
309451
def configure(self):
310452
self.add_argument("input")
311453
self.add_argument("-o", "--output")
454+
self.add_argument("-f", "--format")
312455
self.add_argument("-m", "--model")
313456
self.add_argument("-li", "--lan_input")
314457
self.add_argument("-lo", "--lan_output")
315458
self.add_argument("-s", "--start")
316459
self.add_argument("-e", "--end")
317460
self.add_argument("-v", "--verbosity")
318461

462+
def get_output(self, fmt: TranscriptFormat) -> Path:
463+
if self.output is None:
464+
if self.format == "all":
465+
return self.input.with_suffix(f".{fmt}")
466+
return self.input.with_suffix(f".{fmt}")
467+
return self.output
468+
469+
def translate(self, transcript: Transcript):
470+
if self.lan_output is None:
471+
return transcript
472+
473+
new_transcript, errors = transcript.translate(self.lan_output)
474+
if errors:
475+
lines = (f"{time} : {line} : {error}" for time, line, error in errors)
476+
logging.warning(f"Errors during the translation: {"\n".join(lines)}")
477+
return new_transcript
478+
319479

320480
# ruff: noqa: G004
321481
def main():
322482
logging.basicConfig(level=logging.INFO)
323483
parser = ArgumentParser()
324484
args = parser.parse_args()
325-
326485
logging.info(f"Convert {args.input} to WAV format")
327486
wav_file = to_valid_wav(args.input, start=args.start, end=args.end)
328487

329488
logging.info(f"Transcribe {wav_file}...")
330489
transcript = transcribe(wav_file, args.model, args.max_size)
331-
transcript.language = args.lan_input
332490

333-
if args.lan_output is not None:
334-
new_transcript, errors = transcript.translate(args.lan_output)
335-
if errors:
336-
lines = (f"{time} : {line} : {error}" for time, line, error in errors)
337-
logging.warning(f"Errors during the translation: {"\n".join(lines)}")
338-
else:
339-
new_transcript = transcript
491+
if not args.keep_wav:
492+
wav_file.unlink()
340493

341-
if args.output is None:
342-
args.output = Path(args.input).with_suffix(".txt")
494+
transcript.language = args.lan_input
343495

344-
with args.output.open("w", encoding="utf-8") as f:
345-
f.write(str(new_transcript))
496+
new_transcript = args.translate(transcript)
346497

347-
if not args.keep_wav:
348-
wav_file.unlink()
498+
if args.format == "all":
499+
for fmt in TRANSCRIPT_FORMATS:
500+
new_transcript.write(args.get_output(fmt), fmt)
501+
else:
502+
new_transcript.write(args.get_output(args.format), args.format)

0 commit comments

Comments
 (0)