@@ -30,10 +30,10 @@ def float_to_int16(audio: np.ndarray) -> np.ndarray:
def float_np_array_to_wav_buf(wav: np.ndarray, sr: int) -> BytesIO:
    """Encode a float waveform as 16-bit PCM WAV data in an in-memory buffer.

    Args:
        wav: Samples as a 1-D array (mono) or a 2-D array. A 2-D array is
            transposed before writing, so it is treated as
            (channels, samples).
        sr: Sample rate in Hz.

    Returns:
        A BytesIO holding the complete WAV file, rewound to position 0.
    """
    multi_channel = wav.ndim > 1
    # Take the channel count from the data instead of assuming stereo, so the
    # header stays consistent with the frames for (1, N) or (C > 2, N) input.
    n_channels = wav.shape[0] if multi_channel else 1
    # WAV frames are interleaved: one sample per channel for each time step.
    samples = wav.T if multi_channel else wav
    # tobytes() always emits C-order bytes, so a transposed view that may be
    # non-contiguous is serialised correctly instead of relying on wave's
    # memoryview cast, which rejects non-C-contiguous buffers.
    pcm = float_to_int16(samples).tobytes()

    buf = BytesIO()
    with wave.open(buf, "wb") as wf:
        wf.setnchannels(n_channels)
        wf.setsampwidth(2)  # Sample width in bytes (int16)
        wf.setframerate(sr)  # Sample rate in Hz
        wf.writeframes(pcm)
    buf.seek(0, 0)
    return buf
3939
@@ -60,10 +60,12 @@ def wav2(i: BytesIO, o: BufferedWriter, format: str):
6060 inp .close ()
6161
6262
63- def load_audio (file : Union [str , BytesIO , Path ], sr : Optional [int ]= None , format : Optional [str ]= None ) -> Union [np .ndarray , Tuple [np .ndarray , int ]]:
64- """
65- load audio to mono channel
66- """
63+ def load_audio (
64+ file : Union [str , BytesIO , Path ],
65+ sr : Optional [int ]= None ,
66+ format : Optional [str ]= None ,
67+ mono = True
68+ ) -> Union [np .ndarray , Tuple [np .ndarray , int ]]:
6769 if (isinstance (file , str ) and not Path (file ).exists ()) or (isinstance (file , Path ) and not file .exists ()):
6870 raise FileNotFoundError (f"File not found: { file } " )
6971 rate = 0
@@ -72,7 +74,7 @@ def load_audio(file: Union[str, BytesIO, Path], sr: Optional[int]=None, format:
7274 audio_stream = next (s for s in container .streams if s .type == "audio" )
7375 channels = 1 if audio_stream .layout == "mono" else 2
7476 container .seek (0 )
75- resampler = AudioResampler (format = "fltp" , layout = audio_stream .layout , rate = sr )
77+ resampler = AudioResampler (format = "fltp" , layout = audio_stream .layout , rate = sr ) if sr is not None else None
7678
7779 # Estimated maximum total number of samples to pre-allocate the array
7880 # AV stores length in microseconds by default
@@ -83,19 +85,22 @@ def load_audio(file: Union[str, BytesIO, Path], sr: Optional[int]=None, format:
8385
8486 def process_packet (packet : List [AudioFrame ]):
8587 frames_data = []
88+ rate = 0
8689 for frame in packet :
8790 frame .pts = None # Clear the timestamp to avoid resampling issues
88- resampled_frames = resampler .resample (frame )
91+ resampled_frames = resampler .resample (frame ) if resampler is not None else [ frame ]
8992 for resampled_frame in resampled_frames :
9093 frame_data = resampled_frame .to_ndarray ()
94+ rate = resampled_frame .rate
9195 frames_data .append (frame_data )
92- return frames_data
96+ return ( rate , frames_data )
9397
9498 def frame_iter (container ):
9599 for p in container .demux (container .streams .audio [0 ]):
96100 yield p .decode ()
97101
98- for frames_data in map (process_packet , frame_iter (container )):
102+ for r , frames_data in map (process_packet , frame_iter (container )):
103+ if not rate : rate = r
99104 for frame_data in frames_data :
100105 end_index = offset + len (frame_data [0 ])
101106
@@ -109,6 +114,9 @@ def frame_iter(container):
109114 # Truncate the array to the actual size
110115 decoded_audio = decoded_audio [..., :offset ]
111116
117+ if mono and decoded_audio .shape [0 ] > 1 :
118+ decoded_audio = decoded_audio .mean (0 )
119+
112120 if sr is not None :
113121 return decoded_audio
114122 return decoded_audio , rate
0 commit comments