Skip to content

Commit 17ccf17

Browse files
committed
feat(audio): load_audio support stereo
1 parent e0fffe0 commit 17ccf17

File tree

4 files changed

+21
-14
lines changed

4 files changed

+21
-14
lines changed

infer/lib/audio.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,10 @@ def float_to_int16(audio: np.ndarray) -> np.ndarray:
3030
def float_np_array_to_wav_buf(wav: np.ndarray, sr: int) -> BytesIO:
3131
buf = BytesIO()
3232
with wave.open(buf, "wb") as wf:
33-
wf.setnchannels(1) # Mono channel
33+
wf.setnchannels(2 if len(wav.shape) > 1 else 1) # Mono channel
3434
wf.setsampwidth(2) # Sample width in bytes
3535
wf.setframerate(sr) # Sample rate in Hz
36-
wf.writeframes(float_to_int16(wav))
36+
wf.writeframes(float_to_int16(wav.T if len(wav.shape) > 1 else wav))
3737
buf.seek(0, 0)
3838
return buf
3939

@@ -60,10 +60,12 @@ def wav2(i: BytesIO, o: BufferedWriter, format: str):
6060
inp.close()
6161

6262

63-
def load_audio(file: Union[str, BytesIO, Path], sr: Optional[int]=None, format: Optional[str]=None) -> Union[np.ndarray, Tuple[np.ndarray, int]]:
64-
"""
65-
load audio to mono channel
66-
"""
63+
def load_audio(
64+
file: Union[str, BytesIO, Path],
65+
sr: Optional[int]=None,
66+
format: Optional[str]=None,
67+
mono=True
68+
) -> Union[np.ndarray, Tuple[np.ndarray, int]]:
6769
if (isinstance(file, str) and not Path(file).exists()) or (isinstance(file, Path) and not file.exists()):
6870
raise FileNotFoundError(f"File not found: {file}")
6971
rate = 0
@@ -72,7 +74,7 @@ def load_audio(file: Union[str, BytesIO, Path], sr: Optional[int]=None, format:
7274
audio_stream = next(s for s in container.streams if s.type == "audio")
7375
channels = 1 if audio_stream.layout == "mono" else 2
7476
container.seek(0)
75-
resampler = AudioResampler(format="fltp", layout=audio_stream.layout, rate=sr)
77+
resampler = AudioResampler(format="fltp", layout=audio_stream.layout, rate=sr) if sr is not None else None
7678

7779
# Estimated maximum total number of samples to pre-allocate the array
7880
# AV stores length in microseconds by default
@@ -83,19 +85,22 @@ def load_audio(file: Union[str, BytesIO, Path], sr: Optional[int]=None, format:
8385

8486
def process_packet(packet: List[AudioFrame]):
8587
frames_data = []
88+
rate = 0
8689
for frame in packet:
8790
frame.pts = None # 清除时间戳,避免重新采样问题
88-
resampled_frames = resampler.resample(frame)
91+
resampled_frames = resampler.resample(frame) if resampler is not None else [frame]
8992
for resampled_frame in resampled_frames:
9093
frame_data = resampled_frame.to_ndarray()
94+
rate = resampled_frame.rate
9195
frames_data.append(frame_data)
92-
return frames_data
96+
return (rate, frames_data)
9397

9498
def frame_iter(container):
9599
for p in container.demux(container.streams.audio[0]):
96100
yield p.decode()
97101

98-
for frames_data in map(process_packet, frame_iter(container)):
102+
for r, frames_data in map(process_packet, frame_iter(container)):
103+
if not rate: rate = r
99104
for frame_data in frames_data:
100105
end_index = offset + len(frame_data[0])
101106

@@ -109,6 +114,9 @@ def frame_iter(container):
109114
# Truncate the array to the actual size
110115
decoded_audio = decoded_audio[..., :offset]
111116

117+
if mono and decoded_audio.shape[0] > 1:
118+
decoded_audio = decoded_audio.mean(0)
119+
112120
if sr is not None:
113121
return decoded_audio
114122
return decoded_audio, rate

infer/lib/slicer2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ def main():
229229
out = args.out
230230
if out is None:
231231
out = os.path.dirname(os.path.abspath(args.audio))
232-
audio, sr = load_audio(args.audio)
232+
audio, sr = load_audio(args.audio, mono=False)
233233
slicer = Slicer(
234234
sr=sr,
235235
threshold=args.db_thresh,

infer/modules/train/extract_feature_print.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,6 @@ def readwave(wav_path, normalize=False):
7171
wav, sr = load_audio(wav_path)
7272
assert sr == 16000
7373
feats = torch.from_numpy(wav).float()
74-
if feats.dim() == 2: # double channels
75-
feats = feats.mean(-1)
7674
assert feats.dim() == 1, feats.dim()
7775
if normalize:
7876
with torch.no_grad():

infer/modules/train/preprocess.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,8 @@ def norm_write(self, tmp_audio, idx0, idx1):
6868
load_audio(
6969
float_np_array_to_wav_buf(tmp_audio, self.sr),
7070
sr=16000,
71-
format="wav"
71+
format="wav",
72+
mono=False,
7273
)
7374
, 16000).getbuffer())
7475

0 commit comments

Comments
 (0)