1010import av
1111from av .audio .resampler import AudioResampler
1212from av .audio .frame import AudioFrame
13+ import scipy .io .wavfile as wavfile
1314
1415video_format_dict : Dict [str , str ] = {
1516 "m4a" : "mp4" ,
@@ -27,19 +28,22 @@ def float_to_int16(audio: np.ndarray) -> np.ndarray:
2728 am = 32767 * 32768 // am
2829 return np .multiply (audio , am ).astype (np .int16 )
2930
def float_np_array_to_wav_buf(wav: np.ndarray, sr: int, f32: bool = False) -> BytesIO:
    """Encode an audio array as a WAV file held in an in-memory buffer.

    Args:
        wav: Audio samples. A 1-D array is written as mono. A 2-D array is
            treated as channel-major (one row per channel) and is transposed
            to one frame per row before writing, in both output formats.
        sr: Sample rate in Hz, written to the WAV header.
        f32: When true, write 32-bit float samples via
            ``scipy.io.wavfile.write``; otherwise write 16-bit PCM through
            the standard-library ``wave`` module after ``float_to_int16``.

    Returns:
        A ``BytesIO`` rewound to offset 0 that holds the encoded file.
    """
    multi_channel = wav.ndim > 1
    # Both writers expect one frame per row, so apply the same transpose to
    # either path instead of only to the PCM one.
    frames = wav.T if multi_channel else wav

    buf = BytesIO()
    if f32:
        wavfile.write(buf, sr, frames.astype(np.float32))
    else:
        # wave.writeframes() casts non-bytes input through a memoryview,
        # which only works on C-contiguous data; a transposed array may not be.
        pcm = np.ascontiguousarray(float_to_int16(frames))
        with wave.open(buf, "wb") as wf:
            # Use the real channel count rather than assuming stereo.
            wf.setnchannels(wav.shape[0] if multi_channel else 1)
            wf.setsampwidth(2)  # Sample width in bytes
            wf.setframerate(sr)  # Sample rate in Hz
            wf.writeframes(pcm)
    buf.seek(0, 0)
    return buf
3943
def save_audio(path: str, audio: np.ndarray, sr: int, f32: bool = False):
    """Write ``audio`` to ``path`` as a WAV file.

    Args:
        path: Destination file path; opened in binary write mode.
        audio: Audio samples, passed unchanged to
            ``float_np_array_to_wav_buf``.
        sr: Sample rate in Hz.
        f32: Forwarded to ``float_np_array_to_wav_buf`` to select the
            32-bit float encoding instead of 16-bit PCM.
    """
    # Encode before opening the file so that a failure while encoding does
    # not leave a truncated or empty file at ``path``.
    wav_buf = float_np_array_to_wav_buf(audio, sr, f32)
    with open(path, "wb") as f:
        f.write(wav_buf.getbuffer())
4347
4448def wav2 (i : BytesIO , o : BufferedWriter , format : str ):
4549 inp = av .open (i , "r" )
0 commit comments