11from io import BufferedWriter , BytesIO
22from pathlib import Path
3- from typing import Dict , Tuple
3+ from typing import Dict , Tuple , Optional , Union
44import os
5+ import math
6+ import wave
57
68import numpy as np
9+ from numba import jit
710import av
811from av .audio .resampler import AudioResampler
912
1720}
1821
1922
@jit(nopython=True)
def float_to_int16(audio: np.ndarray) -> np.ndarray:
    """Convert floating-point audio samples to 16-bit signed integers.

    The peak absolute sample is rounded up to a whole number of units and the
    samples are multiplied by ``32767 // units``. A signal whose peak is at
    most 1.0 is therefore scaled by 32767, while a louder signal gets a
    smaller factor so the scaled peak still fits in int16. Quiet signals are
    not boosted.

    Silent (all-zero) and empty inputs are treated as having a peak of zero
    and yield an all-zero / empty int16 array instead of raising.

    Args:
        audio: Array of floating-point samples.

    Returns:
        Array of ``np.int16`` samples.
    """
    # An empty array has no maximum, so treat it as silence.
    peak = float(np.abs(audio).max()) if audio.size > 0 else 0.0
    # Never let the divisor drop below 1: silence would otherwise divide by zero.
    units = max(int(math.ceil(peak)), 1)
    # Same value as 32767 * 32768 // (units * 32768), with the common factor removed.
    scale = 32767 // units
    return np.multiply(audio, scale).astype(np.int16)
28+
def float_np_array_to_wav_buf(wav: np.ndarray, sr: int) -> BytesIO:
    """Encode float samples as a mono 16-bit WAV held in memory.

    Args:
        wav: Floating-point samples; converted with ``float_to_int16``.
        sr: Sample rate in Hz written to the WAV header.

    Returns:
        A ``BytesIO`` containing the WAV data, positioned at the start.
    """
    out = BytesIO()
    writer = wave.open(out, "wb")
    try:
        writer.setnchannels(1)  # single (mono) channel
        writer.setsampwidth(2)  # 2 bytes per sample
        writer.setframerate(sr)  # samples per second
        pcm = float_to_int16(wav)
        writer.writeframes(pcm)
    finally:
        # Closing the writer finalizes the WAV header in the buffer.
        writer.close()
    # Rewind so callers can read the WAV from the beginning.
    out.seek(0)
    return out
38+
def save_audio(path: str, audio: np.ndarray, sr: int):
    """Write float audio samples to ``path`` as a WAV file.

    Args:
        path: Destination file path; opened in binary write mode.
        audio: Floating-point samples to encode.
        sr: Sample rate in Hz.
    """
    with open(path, "wb") as sink:
        encoded = float_np_array_to_wav_buf(audio, sr)
        sink.write(encoded.getbuffer())
42+
2043def wav2 (i : BytesIO , o : BufferedWriter , format : str ):
2144 inp = av .open (i , "r" )
2245 format = video_format_dict .get (format , format )
@@ -36,24 +59,28 @@ def wav2(i: BytesIO, o: BufferedWriter, format: str):
3659 inp .close ()
3760
3861
39- def load_audio (file : str , sr : int ) -> np .ndarray :
40- if not Path (file ).exists ():
62+ def load_audio (file : Union [str , BytesIO , Path ], sr : Optional [int ]= None , format : Optional [str ]= None ) -> Union [np .ndarray , Tuple [np .ndarray , int ]]:
63+ """
64+ load audio to mono channel
65+ """
66+ if (isinstance (file , str ) and not Path (file ).exists ()) or (isinstance (file , Path ) and not file .exists ()):
4167 raise FileNotFoundError (f"File not found: { file } " )
42-
68+ rate = 0
4369 try :
44- container = av .open (file )
70+ container = av .open (file , format = format )
4571 resampler = AudioResampler (format = "fltp" , layout = "mono" , rate = sr )
4672
4773 # Estimated maximum total number of samples to pre-allocate the array
4874 # AV stores length in microseconds by default
49- estimated_total_samples = int (container .duration * sr // 1_000_000 )
75+ estimated_total_samples = int (container .duration * sr // 1_000_000 ) if sr is not None else 48000
5076 decoded_audio = np .zeros (estimated_total_samples + 1 , dtype = np .float32 )
5177
5278 offset = 0
5379 for frame in container .decode (audio = 0 ):
5480 frame .pts = None # Clear presentation timestamp to avoid resampling issues
5581 resampled_frames = resampler .resample (frame )
5682 for resampled_frame in resampled_frames :
83+ rate = resampled_frame .rate
5784 frame_data = resampled_frame .to_ndarray ()[0 ]
5885 end_index = offset + len (frame_data )
5986
@@ -69,10 +96,15 @@ def load_audio(file: str, sr: int) -> np.ndarray:
6996 except Exception as e :
7097 raise RuntimeError (f"Failed to load audio: { e } " )
7198
72- return decoded_audio
99+ if sr is not None :
100+ return decoded_audio
101+ return decoded_audio , rate
73102
74103
75- def downsample_audio (input_path : str , output_path : str , format : str ) -> None :
104+ def downsample_audio (input_path : str , output_path : str , format : str , br = 128_000 ) -> None :
105+ """
106+ default to 128kb/s (equivalent to -q:a 2)
107+ """
76108 if not os .path .exists (input_path ):
77109 return
78110
@@ -83,7 +115,7 @@ def downsample_audio(input_path: str, output_path: str, format: str) -> None:
83115 input_stream = input_container .streams .audio [0 ]
84116 output_stream = output_container .add_stream (format )
85117
86- output_stream .bit_rate = 128_000 # 128kb/s (equivalent to -q:a 2)
118+ output_stream .bit_rate = br
87119
88120 # Copy packets from the input file to the output file
89121 for packet in input_container .demux (input_stream ):
0 commit comments