11from io import BufferedWriter , BytesIO
22from pathlib import Path
3- from typing import Dict , Tuple
3+ from typing import Dict , Tuple , Optional , Union , List
44import os
5+ import math
6+ import wave
57
68import numpy as np
9+ from numba import jit
710import av
811from av .audio .resampler import AudioResampler
12+ from av .audio .frame import AudioFrame
13+ import scipy .io .wavfile as wavfile
914
1015video_format_dict : Dict [str , str ] = {
1116 "m4a" : "mp4" ,
1722}
1823
1924
@jit(nopython=True)
def float_to_int16(audio: np.ndarray) -> np.ndarray:
    """Convert floating-point samples to 16-bit signed integer samples.

    Every sample is multiplied by the integer gain ``32767 // peak`` and the
    result is cast to ``int16``. ``peak`` is the largest absolute sample
    value rounded up to a whole number, so input within ``[-1, 1]`` gets a
    gain of 32767 and louder input gets a proportionally smaller gain.

    Args:
        audio: Array of floating-point samples; any shape.

    Returns:
        An ``int16`` array with the same shape as ``audio``.
    """
    peak = 0
    if audio.size > 0:
        peak = int(math.ceil(float(np.abs(audio).max())))
    if peak < 1:
        # Empty or completely silent input: there is no peak to normalise
        # against, and a zero peak would make the division below fail.
        peak = 1
    # Same value as 32767 * 32768 // (peak * 32768), without the common factor.
    gain = 32767 // peak
    return np.multiply(audio, gain).astype(np.int16)
30+
def float_np_array_to_wav_buf(wav: np.ndarray, sr: int, f32=False) -> BytesIO:
    """Encode floating-point audio as a WAV file held in memory.

    Args:
        wav: Samples shaped ``(samples,)`` for a single channel, or
            ``(channels, samples)`` for several channels.
        sr: Sample rate in Hz.
        f32: When true, write 32-bit float samples with
            ``scipy.io.wavfile``; otherwise write 16-bit PCM with the
            ``wave`` module after converting through ``float_to_int16``.

    Returns:
        A ``BytesIO`` holding the WAV data, rewound to position 0.
    """
    multichannel = wav.ndim > 1
    # Both writers consume one row per frame, i.e. (samples, channels),
    # so the channel-first input has to be transposed for either branch.
    frames = wav.T if multichannel else wav

    buf = BytesIO()
    if f32:
        wavfile.write(buf, sr, np.ascontiguousarray(frames, dtype=np.float32))
    else:
        with wave.open(buf, "wb") as wf:
            # Take the channel count from the data instead of assuming stereo.
            wf.setnchannels(wav.shape[0] if multichannel else 1)
            wf.setsampwidth(2)  # Sample width in bytes
            wf.setframerate(sr)  # Sample rate in Hz
            # tobytes() serialises in C order, which yields interleaved
            # frames even though the transposed array is not C-contiguous.
            wf.writeframes(float_to_int16(frames).tobytes())
    buf.seek(0, 0)
    return buf
43+
def save_audio(path: str, audio: np.ndarray, sr: int, f32=False):
    """Write audio samples to ``path`` as a WAV file.

    Args:
        path: Destination file; it is created or overwritten.
        audio: Samples, passed unchanged to ``float_np_array_to_wav_buf``.
        sr: Sample rate in Hz.
        f32: Passed through to ``float_np_array_to_wav_buf`` to choose
            between 32-bit float and 16-bit PCM output.
    """
    # Encode before opening the destination so that a failed conversion
    # cannot leave behind an empty or truncated file.
    encoded = float_np_array_to_wav_buf(audio, sr, f32)
    with open(path, "wb") as f:
        f.write(encoded.getbuffer())
47+
2048def wav2 (i : BytesIO , o : BufferedWriter , format : str ):
2149 inp = av .open (i , "r" )
2250 format = video_format_dict .get (format , format )
@@ -36,43 +64,72 @@ def wav2(i: BytesIO, o: BufferedWriter, format: str):
3664 inp .close ()
3765
3866
def _grow_sample_buffer(buffer: np.ndarray, used: int, needed: int) -> np.ndarray:
    """Return a buffer with at least ``needed`` columns, keeping the first ``used``."""
    if needed <= buffer.shape[1]:
        return buffer
    # np.resize would repeat the flattened data and mix the channel rows,
    # so allocate a larger array and copy each row's filled prefix across.
    grown = np.zeros((buffer.shape[0], needed * 2), dtype=buffer.dtype)
    grown[:, :used] = buffer[:, :used]
    return grown


def load_audio(
    file: Union[str, BytesIO, Path],
    sr: Optional[int] = None,
    format: Optional[str] = None,
    mono=True,
) -> Union[np.ndarray, Tuple[np.ndarray, int]]:
    """Decode an audio source into a float32 NumPy array.

    Args:
        file: Path (``str`` or ``Path``) or an in-memory ``BytesIO`` object
            to decode.
        sr: Target sample rate. When ``None`` the rate reported by the
            decoded frames is kept and returned alongside the samples.
        format: Optional container format forwarded to ``av.open``.
        mono: When true, multi-channel audio is averaged down to a single
            channel.

    Returns:
        The samples when ``sr`` is given, otherwise ``(samples, rate)``.
        Single-channel audio and down-mixed audio are 1-D; multi-channel
        audio with ``mono=False`` is shaped ``(channels, samples)``.

    Raises:
        FileNotFoundError: If ``file`` is a path that does not exist.
    """
    if isinstance(file, (str, Path)) and not Path(file).exists():
        raise FileNotFoundError(f"File not found: {file}")

    rate = 0
    offset = 0
    decoded_audio = None

    # The context manager closes the container even if decoding fails.
    with av.open(file, format=format) as container:
        audio_stream = next(s for s in container.streams if s.type == "audio")
        container.seek(0)
        # Always convert to planar float so every frame arrives as a
        # (channels, samples) float array; rate=None leaves the rate unset.
        resampler = AudioResampler(
            format="fltp", layout=audio_stream.layout, rate=sr
        )

        # Estimated total number of samples used to pre-allocate the array.
        # AV stores length in microseconds by default; it may be missing.
        duration = container.duration
        if sr is not None and duration is not None:
            estimated_total_samples = max(int(duration * sr // 1_000_000), 0)
        else:
            estimated_total_samples = 48000

        for packet in container.demux(audio_stream):
            for frame in packet.decode():
                # Clear the timestamp to avoid resampling issues.
                frame.pts = None
                for resampled_frame in resampler.resample(frame):
                    frame_data = resampled_frame.to_ndarray()
                    if not rate:
                        rate = resampled_frame.rate
                    if decoded_audio is None:
                        # Size the buffer from the first decoded frame so the
                        # channel count is taken from the data itself.
                        decoded_audio = np.zeros(
                            (frame_data.shape[0], estimated_total_samples + 1),
                            dtype=np.float32,
                        )
                    end_index = offset + frame_data.shape[1]
                    # Make sure there is room for this frame before copying.
                    decoded_audio = _grow_sample_buffer(
                        decoded_audio, offset, end_index
                    )
                    decoded_audio[:, offset:end_index] = frame_data
                    offset = end_index

    if decoded_audio is None:
        # No frames were decoded at all.
        decoded_audio = np.zeros((1, 0), dtype=np.float32)

    # Truncate the array to the actual size
    decoded_audio = decoded_audio[:, :offset]

    if decoded_audio.shape[0] == 1:
        decoded_audio = decoded_audio[0]
    elif mono:
        decoded_audio = decoded_audio.mean(0)

    if sr is not None:
        return decoded_audio
    return decoded_audio, rate
73127
74128
75- def downsample_audio (input_path : str , output_path : str , format : str ) -> None :
129+ def downsample_audio (input_path : str , output_path : str , format : str , br = 128_000 ) -> None :
130+ """
131+ default to 128kb/s (equivalent to -q:a 2)
132+ """
76133 if not os .path .exists (input_path ):
77134 return
78135
@@ -83,7 +140,7 @@ def downsample_audio(input_path: str, output_path: str, format: str) -> None:
83140 input_stream = input_container .streams .audio [0 ]
84141 output_stream = output_container .add_stream (format )
85142
86- output_stream .bit_rate = 128_000 # 128kb/s (equivalent to -q:a 2)
143+ output_stream .bit_rate = br
87144
88145 # Copy packets from the input file to the output file
89146 for packet in input_container .demux (input_stream ):
@@ -141,7 +198,7 @@ def resample_audio(
141198 print (f"Failed to remove the original file: { e } " )
142199
143200
144- def get_audio_properties (input_path : str ) -> Tuple :
201+ def get_audio_properties (input_path : str ) -> Tuple [ int , int ] :
145202 container = av .open (input_path )
146203 audio_stream = next (s for s in container .streams if s .type == "audio" )
147204 channels = 1 if audio_stream .layout == "mono" else 2
0 commit comments