Skip to content

Commit 75b6ab6

Browse files
committed
fix(audio): float32 wav saving
1 parent 17ccf17 commit 75b6ab6

File tree

3 files changed: 19 additions and 16 deletions.

infer/lib/audio.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
 import av
 from av.audio.resampler import AudioResampler
 from av.audio.frame import AudioFrame
+import scipy.io.wavfile as wavfile
 
 video_format_dict: Dict[str, str] = {
     "m4a": "mp4",
@@ -27,19 +28,22 @@ def float_to_int16(audio: np.ndarray) -> np.ndarray:
     am = 32767 * 32768 // am
     return np.multiply(audio, am).astype(np.int16)
 
-def float_np_array_to_wav_buf(wav: np.ndarray, sr: int) -> BytesIO:
+def float_np_array_to_wav_buf(wav: np.ndarray, sr: int, f32=False) -> BytesIO:
     buf = BytesIO()
-    with wave.open(buf, "wb") as wf:
-        wf.setnchannels(2 if len(wav.shape) > 1 else 1) # Mono channel
-        wf.setsampwidth(2) # Sample width in bytes
-        wf.setframerate(sr) # Sample rate in Hz
-        wf.writeframes(float_to_int16(wav.T if len(wav.shape) > 1 else wav))
+    if f32:
+        wavfile.write(buf, sr, wav.astype(np.float32))
+    else:
+        with wave.open(buf, "wb") as wf:
+            wf.setnchannels(2 if len(wav.shape) > 1 else 1)
+            wf.setsampwidth(2) # Sample width in bytes
+            wf.setframerate(sr) # Sample rate in Hz
+            wf.writeframes(float_to_int16(wav.T if len(wav.shape) > 1 else wav))
     buf.seek(0, 0)
     return buf
 
-def save_audio(path: str, audio: np.ndarray, sr: int):
+def save_audio(path: str, audio: np.ndarray, sr: int, f32=False):
     with open(path, "wb") as f:
-        f.write(float_np_array_to_wav_buf(audio, sr).getbuffer())
+        f.write(float_np_array_to_wav_buf(audio, sr, f32).getbuffer())
 
 def wav2(i: BytesIO, o: BufferedWriter, format: str):
     inp = av.open(i, "r")

infer/modules/train/preprocess.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,16 +62,15 @@ def norm_write(self, tmp_audio, idx0, idx1):
         tmp_audio = (tmp_audio / tmp_max * (self.max * self.alpha)) + (
             1 - self.alpha
         ) * tmp_audio
-        save_audio("%s/%s_%s.wav" % (self.gt_wavs_dir, idx0, idx1), tmp_audio, self.sr)
+        save_audio("%s/%s_%s.wav" % (self.gt_wavs_dir, idx0, idx1), tmp_audio, self.sr, f32=True)
         with open("%s/%s_%s.wav" % (self.wavs16k_dir, idx0, idx1), "wb") as f:
             f.write(float_np_array_to_wav_buf(
                 load_audio(
-                    float_np_array_to_wav_buf(tmp_audio, self.sr),
+                    float_np_array_to_wav_buf(tmp_audio, self.sr, f32=True),
                     sr=16000,
                     format="wav",
-                    mono=False,
                 )
-                , 16000).getbuffer())
+                , 16000, True).getbuffer())
 
     def pipeline(self, path, idx0):
         try:

web.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -141,8 +141,8 @@ def forward_dml(ctx, x, scale):
 index_root = os.getenv("index_root")
 outside_index_root = os.getenv("outside_index_root")
 
-names = []
-index_paths = []
+names = [""]
+index_paths = [""]
 
 def lookup_names(weight_root):
     global names
@@ -168,9 +168,9 @@ def lookup_indices(index_root):
 
 def change_choices():
     global index_paths, names
-    names = []
+    names = [""]
     lookup_names(weight_root)
-    index_paths = []
+    index_paths = [""]
     lookup_indices(index_root)
     lookup_indices(outside_index_root)
     return {"choices": sorted(names), "__type__": "update"}, {

0 commit comments

Comments
 (0)