22
33import json
44import logging
5+ import typing
56import warnings
67import wave
78from dataclasses import dataclass , field
89from pathlib import Path
9- from typing import NamedTuple
10+ from typing import TYPE_CHECKING , Literal , NamedTuple
1011
1112import deep_translator
1213import ffmpeg
1314import tap # typed_argument_parser
1415import vosk
1516from tqdm import tqdm
1617
18+ if TYPE_CHECKING :
19+ from os import PathLike
20+ type StrPath = str | PathLike [str ]
21+
22+
23+ TranscriptFormat = Literal ["csv" , "json" , "srt" , "txt" , "vtt" ]
24+ TRANSCRIPT_FORMATS : tuple [TranscriptFormat ] = typing .get_args (TranscriptFormat )
25+
1726
1827class LineError (NamedTuple ):
1928 time : float
@@ -51,6 +60,23 @@ def seconds_to_time(seconds: float) -> str:
5160 return f"{ days } d { hours :02d} :{ seconds_to_time (minutes )} "
5261
5362
63+ def seconds_to_srt_time (seconds : float ) -> str :
64+ """Convert seconds to SRT time format.
65+
66+ Args:
67+ seconds (float): the number of seconds
68+
69+ Returns:
70+ str: the time in the format "hh:mm:ss,ms"
71+ """
72+ seconds = float (seconds )
73+ hours , remainder = divmod (seconds , 3600 )
74+ minutes , seconds = divmod (remainder , 60 )
75+ int_seconds , dec_seconds = str (seconds ).split ("." )
76+ dec_seconds = dec_seconds [:3 ]
77+ return f"{ int (hours ):02d} :{ int (minutes ):02d} :{ int (int_seconds ):02d} ,{ dec_seconds } "
78+
79+
5480@dataclass
5581class Transcript :
5682 """A transcript of a video or audio file.
@@ -60,11 +86,20 @@ class Transcript:
6086 time (list[float]): a list of the time of each line in the transcript.
6187 text (list[str]): a list of the text of each line in the transcript.
6288 language (str): the language of the transcript. Default: "auto"
89+ time_end (float): the time of the last line in the transcript. If not specified,
90+ it will be the time of the last line + 5 seconds.
6391 """
6492
6593 time : list [float ] = field (default_factory = list )
6694 text : list [str ] = field (default_factory = list )
6795 language : str = "auto"
96+ time_end : float | None = None
97+
98+ @property
99+ def _time_end (self ) -> float :
100+ if self .time_end is not None :
101+ return self .time_end
102+ return self .time [- 1 ] + 5
68103
69104 def append (self , time : float , text : str ) -> None :
70105 self .time .append (time )
@@ -76,6 +111,12 @@ def __str__(self):
76111 for time , line in zip (self .time , self .text , strict = True )
77112 )
78113
114+ def __len__ (self ):
115+ return len (self .time )
116+
117+ def __getitem__ (self , key ):
118+ return self .time [key ], self .text [key ]
119+
79120 def translate (self , target : str ) -> tuple [Transcript , list [LineError ]]:
80121 """Return a translated version of the transcript.
81122
@@ -89,11 +130,13 @@ def translate(self, target: str) -> tuple[Transcript, list[LineError]]:
89130 - Transcript: the translated transcript
90131 - list[LineError]: a list of errors that occurred during the translation
91132 """
92- translated = Transcript ()
133+ translated = Transcript (time_end = self . time_end )
93134 errors : list [LineError ] = []
94135
95136 _iter = zip (self .time , self .text , strict = True )
96- pbar = tqdm (_iter , total = len (self .time ), unit_scale = True , unit = "line" )
137+ pbar = tqdm (
138+ _iter , total = len (self .time ), unit_scale = True , unit = "line" , desc = "Translate"
139+ )
97140 for time , line in pbar :
98141 try :
99142 translator = deep_translator .GoogleTranslator (
@@ -105,9 +148,75 @@ def translate(self, target: str) -> tuple[Transcript, list[LineError]]:
105148 errors .append (LineError (time , line , e ))
106149 return translated , errors
107150
151+ def srt_generator (self ):
152+ """Generate the transcript as a string in SRT format, line by line."""
153+
154+ def one_line (start , end , line ):
155+ start , end = map (seconds_to_srt_time , (start , end ))
156+ return f"{ start } --> { end } \n { line } \n \n "
157+
158+ nb_lines = len (self )
159+ for i , (time , line ) in enumerate (
160+ zip (self .time , self .text , strict = True ), start = 1
161+ ):
162+ if i == nb_lines :
163+ end = self ._time_end
164+ else :
165+ after_time = self .time [i ]
166+ end = min (after_time , time + 5 )
167+ yield one_line (time , end , line )
168+
169+ def vtt_generator (self ):
170+ """Generate the transcript as a string in VTT format, line by line."""
171+ yield "WEBVTT\n \n "
172+ for srt_line in self .srt_generator ():
173+ if " --> " in srt_line :
174+ yield srt_line .replace ("," , "." )
175+ else :
176+ yield srt_line
177+
178+ def csv_generator (self ):
179+ """Generate the transcript as a string in CSV format, line by line."""
180+ yield "time,text\n "
181+ for time , line in zip (self .time , self .text , strict = True ):
182+ yield f"{ time } ,{ line } \n "
183+
184+ def to_srt (self ) -> str :
185+ """Return the transcript as a string in SRT format."""
186+ return "" .join (self .srt_generator ())
187+
188+ def to_vtt (self ) -> str :
189+ """Return the transcript as a string in VTT format."""
190+ return "" .join (self .vtt_generator ())
191+
192+ def to_json (self ) -> str :
193+ """Return the transcript as a string in JSON format."""
194+ return json .dumps ({"text" : self .text , "time" : self .time })
195+
196+ def to_txt (self ) -> str :
197+ """Return the transcript as a string in TXT format."""
198+ return str (self )
199+
200+ def to_csv (self ) -> str :
201+ """Return the transcript as a string in CSV format."""
202+ return "" .join (self .csv_generator ())
203+
204+ def write (self , output : StrPath , format : TranscriptFormat ) -> None : # noqa: A002
205+ """Write the transcript to a file.
206+
207+ Args:
208+ output (Path): the path to the output file.
209+ format (TranscriptFormat): the format of the transcript.
210+ """
211+ method = getattr (self , f"to_{ format } " )
212+ Path (output ).write_text (method ())
213+
108214
109215def to_valid_wav (
110- source : Path , output : Path | None = None , start : float = 0 , end : float | None = None
216+ source : StrPath ,
217+ output : StrPath | None = None ,
218+ start : float = 0 ,
219+ end : float | None = None ,
111220) -> Path :
112221 """Convert a video or audio file to a wav file.
113222
@@ -126,13 +235,16 @@ def to_valid_wav(
126235 Returns:
127236 Path of the converted file.
128237 """
129- start , end = int (start * 1000 ), int (end * 1000 ) if end is not None else None
238+ source = Path (source )
239+ start = int (start * 1000 )
240+ end = int (end * 1000 ) if end is not None else None
130241 wav_file = source .with_suffix (".wav" )
131242 if wav_file == source :
132243 if _is_valid_wav_file (source ):
133244 return source
134- wav_file = source .rename (f"{ source .stem } _converted.wav" )
135- output_path = wav_file if output is None else output
245+ wav_file = Path (f'{ source .with_suffix ("" )} _converted.wav' )
246+
247+ output_path = wav_file if output is None else Path (output )
136248
137249 args = {"ss" : start , "loglevel" : "warning" }
138250 if end is not None :
@@ -172,7 +284,7 @@ def parse_data_buffer(
172284
173285
174286def transcribe (
175- input_file : Path , model_path : Path , max_size : int | None = None
287+ input_file : StrPath , model_path : StrPath , max_size : int | None = None
176288) -> Transcript :
177289 """Transcribe a mono PCM 16-bit WAV file using a vosk model
178290 (https://alphacephei.com/vosk/models).
@@ -190,6 +302,9 @@ def transcribe(
190302 Returns:
191303 Transcript: the transcript of the file
192304 """
305+ input_file = Path (input_file )
306+ model_path = Path (model_path )
307+
193308 if not input_file .is_file ():
194309 msg = f"{ input_file } not found"
195310 raise FileNotFoundError (msg )
@@ -209,7 +324,14 @@ def transcribe(
209324
210325def _is_valid_wav_file (input_file : Path ) -> bool :
211326 """Validate if the input file is a valid WAV file."""
212- wf = wave .Wave_read (str (input_file ))
327+ try :
328+ wf = wave .Wave_read (str (input_file ))
329+ except wave .Error as e :
330+ # if it is not a valid wav file for wave_read itself
331+ if "unknown format" in str (e ):
332+ return False
333+ raise e from None
334+
213335 is_mono = wf .getnchannels () == 1
214336 is_pcm = wf .getcomptype () == "NONE"
215337 is_16bit = wf .getsampwidth () == 2 # noqa: PLR2004
@@ -218,7 +340,6 @@ def _is_valid_wav_file(input_file: Path) -> bool:
218340
219341def _initialize_recognizer (model : vosk .Model , input_file : Path ) -> vosk .KaldiRecognizer :
220342 """Initialize the Vosk recognizer."""
221- # for a weird reason, Wave_read does not work with Path objects
222343 wave_form = wave .Wave_read (str (input_file ))
223344 rec = vosk .KaldiRecognizer (model , wave_form .getframerate ())
224345
@@ -231,16 +352,21 @@ def _initialize_recognizer(model: vosk.Model, input_file: Path) -> vosk.KaldiRec
231352
232353
233354def transcribe_with_vosk (
234- input_file : Path , rec : vosk .KaldiRecognizer , max_size : int | None
355+ input_file : StrPath , rec : vosk .KaldiRecognizer , max_size : int | None
235356) -> Transcript :
236357 """Transcribe the file using the Vosk recognizer."""
358+ input_file = Path (input_file )
359+
237360 wave_form = wave .Wave_read (str (input_file ))
238361 file_size = input_file .stat ().st_size
239362 if max_size is not None and max_size < file_size :
240363 file_size = max_size
241- pbar = tqdm (total = file_size , unit = "B" , unit_scale = True )
364+ pbar = tqdm (
365+ total = file_size , unit = "B" , unit_scale = True , desc = f"Transcribe { input_file } "
366+ )
242367
243- transcript = Transcript ()
368+ time_end = wave_form .getnframes () / wave_form .getframerate ()
369+ transcript = Transcript (time_end = time_end )
244370 total_data = 0
245371 len_data = 1 # initialize with 1 to enter the loop
246372 while len_data > 0 and total_data < file_size :
@@ -256,14 +382,26 @@ def transcribe_with_vosk(
256382 return transcript
257383
258384
385+ AllTranscriptFormats = TranscriptFormat | Literal ["all" ]
386+
387+
259388class ArgumentParser (tap .Tap ):
260389 """Transcribe a file and optionally translate the transcript."""
261390
262391 input : Path
263392 "the path to the audio file"
264393
265394 output : Path | None = None
266- "the path to the output file. Default: input file with .txt extension"
395+ """
396+ the path to the output file. Default: same as the input file with only the extension
397+ changed
398+ """
399+
400+ format : str = "all"
401+ """
402+ the format of the transcript. Must be one of 'csv', 'json', 'srt', 'txt', 'vtt'
403+ or 'all'. Default: 'all'
404+ """
267405
268406 model : Path = Path ("model" )
269407 "the path to the vosk model"
@@ -291,6 +429,10 @@ class ArgumentParser(tap.Tap):
291429 3: debug. Default: 2."""
292430
293431 def process_args (self ):
432+ if self .format not in typing .get_args (AllTranscriptFormats ):
433+ msg = f"bad transcript format: { self .format } "
434+ raise ValueError (msg )
435+
294436 vosk .SetLogLevel (- 1 ) # disable vosk logs
295437 match self .verbosity :
296438 case 0 :
@@ -309,40 +451,52 @@ def process_args(self):
309451 def configure (self ):
310452 self .add_argument ("input" )
311453 self .add_argument ("-o" , "--output" )
454+ self .add_argument ("-f" , "--format" )
312455 self .add_argument ("-m" , "--model" )
313456 self .add_argument ("-li" , "--lan_input" )
314457 self .add_argument ("-lo" , "--lan_output" )
315458 self .add_argument ("-s" , "--start" )
316459 self .add_argument ("-e" , "--end" )
317460 self .add_argument ("-v" , "--verbosity" )
318461
462+ def get_output (self , fmt : TranscriptFormat ) -> Path :
463+ if self .output is None :
464+ if self .format == "all" :
465+ return self .input .with_suffix (f".{ fmt } " )
466+ return self .input .with_suffix (f".{ fmt } " )
467+ return self .output
468+
469+ def translate (self , transcript : Transcript ):
470+ if self .lan_output is None :
471+ return transcript
472+
473+ new_transcript , errors = transcript .translate (self .lan_output )
474+ if errors :
475+ lines = (f"{ time } : { line } : { error } " for time , line , error in errors )
476+ logging .warning (f"Errors during the translation: { "\n " .join (lines )} " )
477+ return new_transcript
478+
319479
320480# ruff: noqa: G004
321481def main ():
322482 logging .basicConfig (level = logging .INFO )
323483 parser = ArgumentParser ()
324484 args = parser .parse_args ()
325-
326485 logging .info (f"Convert { args .input } to WAV format" )
327486 wav_file = to_valid_wav (args .input , start = args .start , end = args .end )
328487
329488 logging .info (f"Transcribe { wav_file } ..." )
330489 transcript = transcribe (wav_file , args .model , args .max_size )
331- transcript .language = args .lan_input
332490
333- if args .lan_output is not None :
334- new_transcript , errors = transcript .translate (args .lan_output )
335- if errors :
336- lines = (f"{ time } : { line } : { error } " for time , line , error in errors )
337- logging .warning (f"Errors during the translation: { "\n " .join (lines )} " )
338- else :
339- new_transcript = transcript
491+ if not args .keep_wav :
492+ wav_file .unlink ()
340493
341- if args .output is None :
342- args .output = Path (args .input ).with_suffix (".txt" )
494+ transcript .language = args .lan_input
343495
344- with args .output .open ("w" , encoding = "utf-8" ) as f :
345- f .write (str (new_transcript ))
496+ new_transcript = args .translate (transcript )
346497
347- if not args .keep_wav :
348- wav_file .unlink ()
498+ if args .format == "all" :
499+ for fmt in TRANSCRIPT_FORMATS :
500+ new_transcript .write (args .get_output (fmt ), fmt )
501+ else :
502+ new_transcript .write (args .get_output (args .format ), args .format )
0 commit comments