Skip to content

Commit 841c546

Browse files
committed
optimzie(train&uvr5): rm sf & simp. AudioPre
1 parent 4b68fb0 commit 841c546

File tree

17 files changed

+111
-279
lines changed

17 files changed

+111
-279
lines changed

infer/lib/audio.py

Lines changed: 41 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
from io import BufferedWriter, BytesIO
22
from pathlib import Path
3-
from typing import Dict, Tuple
3+
from typing import Dict, Tuple, Optional, Union
44
import os
5+
import math
6+
import wave
57

68
import numpy as np
9+
from numba import jit
710
import av
811
from av.audio.resampler import AudioResampler
912

@@ -17,6 +20,26 @@
1720
}
1821

1922

23+
@jit(nopython=True)
24+
def float_to_int16(audio: np.ndarray) -> np.ndarray:
25+
am = int(math.ceil(float(np.abs(audio).max())) * 32768)
26+
am = 32767 * 32768 // am
27+
return np.multiply(audio, am).astype(np.int16)
28+
29+
def float_np_array_to_wav_buf(wav: np.ndarray, sr: int) -> BytesIO:
30+
buf = BytesIO()
31+
with wave.open(buf, "wb") as wf:
32+
wf.setnchannels(1) # Mono channel
33+
wf.setsampwidth(2) # Sample width in bytes
34+
wf.setframerate(sr) # Sample rate in Hz
35+
wf.writeframes(float_to_int16(wav))
36+
buf.seek(0, 0)
37+
return buf
38+
39+
def save_audio(path: str, audio: np.ndarray, sr: int):
40+
with open(path, "wb") as f:
41+
f.write(float_np_array_to_wav_buf(audio, sr).getbuffer())
42+
2043
def wav2(i: BytesIO, o: BufferedWriter, format: str):
2144
inp = av.open(i, "r")
2245
format = video_format_dict.get(format, format)
@@ -36,24 +59,28 @@ def wav2(i: BytesIO, o: BufferedWriter, format: str):
3659
inp.close()
3760

3861

39-
def load_audio(file: str, sr: int) -> np.ndarray:
40-
if not Path(file).exists():
62+
def load_audio(file: Union[str, BytesIO, Path], sr: Optional[int]=None, format: Optional[str]=None) -> Union[np.ndarray, Tuple[np.ndarray, int]]:
63+
"""
64+
load audio to mono channel
65+
"""
66+
if (isinstance(file, str) and not Path(file).exists()) or (isinstance(file, Path) and not file.exists()):
4167
raise FileNotFoundError(f"File not found: {file}")
42-
68+
rate = 0
4369
try:
44-
container = av.open(file)
70+
container = av.open(file, format=format)
4571
resampler = AudioResampler(format="fltp", layout="mono", rate=sr)
4672

4773
# Estimated maximum total number of samples to pre-allocate the array
4874
# AV stores length in microseconds by default
49-
estimated_total_samples = int(container.duration * sr // 1_000_000)
75+
estimated_total_samples = int(container.duration * sr // 1_000_000) if sr is not None else 48000
5076
decoded_audio = np.zeros(estimated_total_samples + 1, dtype=np.float32)
5177

5278
offset = 0
5379
for frame in container.decode(audio=0):
5480
frame.pts = None # Clear presentation timestamp to avoid resampling issues
5581
resampled_frames = resampler.resample(frame)
5682
for resampled_frame in resampled_frames:
83+
rate = resampled_frame.rate
5784
frame_data = resampled_frame.to_ndarray()[0]
5885
end_index = offset + len(frame_data)
5986

@@ -69,10 +96,15 @@ def load_audio(file: str, sr: int) -> np.ndarray:
6996
except Exception as e:
7097
raise RuntimeError(f"Failed to load audio: {e}")
7198

72-
return decoded_audio
99+
if sr is not None:
100+
return decoded_audio
101+
return decoded_audio, rate
73102

74103

75-
def downsample_audio(input_path: str, output_path: str, format: str) -> None:
104+
def downsample_audio(input_path: str, output_path: str, format: str, br=128_000) -> None:
105+
"""
106+
default to 128kb/s (equivalent to -q:a 2)
107+
"""
76108
if not os.path.exists(input_path):
77109
return
78110

@@ -83,7 +115,7 @@ def downsample_audio(input_path: str, output_path: str, format: str) -> None:
83115
input_stream = input_container.streams.audio[0]
84116
output_stream = output_container.add_stream(format)
85117

86-
output_stream.bit_rate = 128_000 # 128kb/s (equivalent to -q:a 2)
118+
output_stream.bit_rate = br
87119

88120
# Copy packets from the input file to the output file
89121
for packet in input_container.demux(input_stream):

infer/lib/slicer2.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -183,8 +183,7 @@ def main():
183183
import os.path
184184
from argparse import ArgumentParser
185185

186-
import librosa
187-
import soundfile
186+
from .audio import load_audio, save_audio
188187

189188
parser = ArgumentParser()
190189
parser.add_argument("audio", type=str, help="The audio to be sliced")
@@ -230,7 +229,7 @@ def main():
230229
out = args.out
231230
if out is None:
232231
out = os.path.dirname(os.path.abspath(args.audio))
233-
audio, sr = librosa.load(args.audio, sr=None, mono=False)
232+
audio, sr = load_audio(args.audio)
234233
slicer = Slicer(
235234
sr=sr,
236235
threshold=args.db_thresh,
@@ -245,15 +244,11 @@ def main():
245244
for i, chunk in enumerate(chunks):
246245
if len(chunk.shape) > 1:
247246
chunk = chunk.T
248-
soundfile.write(
249-
os.path.join(
250-
out,
251-
f"%s_%d.wav"
252-
% (os.path.basename(args.audio).rsplit(".", maxsplit=1)[0], i),
253-
),
254-
chunk,
255-
sr,
256-
)
247+
save_audio(os.path.join(
248+
out,
249+
f"%s_%d.wav"
250+
% (os.path.basename(args.audio).rsplit(".", maxsplit=1)[0], i),
251+
), chunk, sr)
257252

258253

259254
if __name__ == "__main__":

infer/modules/train/extract_feature_print.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,11 @@
22
import sys
33
import traceback
44

5+
now_dir = os.getcwd()
6+
sys.path.append(now_dir)
7+
8+
from infer.lib.audio import load_audio
9+
510
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
611
os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0"
712

@@ -20,7 +25,6 @@
2025
is_half = sys.argv[7].lower() == "true"
2126
import fairseq
2227
import numpy as np
23-
import soundfile as sf
2428
import torch
2529
import torch.nn.functional as F
2630

@@ -64,7 +68,7 @@ def printt(strr):
6468

6569
# wave must be 16k, hop_size=320
6670
def readwave(wav_path, normalize=False):
67-
wav, sr = sf.read(wav_path)
71+
wav, sr = load_audio(wav_path)
6872
assert sr == 16000
6973
feats = torch.from_numpy(wav).float()
7074
if feats.dim() == 2: # double channels

infer/modules/train/preprocess.py

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,9 @@
1616
import os
1717
import traceback
1818

19-
import librosa
2019
import numpy as np
21-
from scipy.io import wavfile
2220

23-
from infer.lib.audio import load_audio
21+
from infer.lib.audio import load_audio, float_np_array_to_wav_buf, save_audio
2422
from infer.lib.slicer2 import Slicer
2523

2624
f = open("%s/preprocess.log" % exp_dir, "a+")
@@ -64,19 +62,15 @@ def norm_write(self, tmp_audio, idx0, idx1):
6462
tmp_audio = (tmp_audio / tmp_max * (self.max * self.alpha)) + (
6563
1 - self.alpha
6664
) * tmp_audio
67-
wavfile.write(
68-
"%s/%s_%s.wav" % (self.gt_wavs_dir, idx0, idx1),
69-
self.sr,
70-
tmp_audio.astype(np.float32),
71-
)
72-
tmp_audio = librosa.resample(
73-
tmp_audio, orig_sr=self.sr, target_sr=16000
74-
) # , res_type="soxr_vhq"
75-
wavfile.write(
76-
"%s/%s_%s.wav" % (self.wavs16k_dir, idx0, idx1),
77-
16000,
78-
tmp_audio.astype(np.float32),
79-
)
65+
save_audio("%s/%s_%s.wav" % (self.gt_wavs_dir, idx0, idx1), tmp_audio, self.sr)
66+
with open("%s/%s_%s.wav" % (self.wavs16k_dir, idx0, idx1), "wb") as f:
67+
f.write(float_np_array_to_wav_buf(
68+
load_audio(
69+
float_np_array_to_wav_buf(tmp_audio, self.sr),
70+
sr=16000,
71+
format="wav"
72+
)
73+
, 16000).getbuffer())
8074

8175
def pipeline(self, path, idx0):
8276
try:

infer/modules/uvr5/mdxnet.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,10 @@
55

66
import librosa
77
import numpy as np
8-
import soundfile as sf
98
import torch
109
from tqdm import tqdm
11-
import av
1210

13-
from infer.lib.audio import downsample_audio
11+
from infer.lib.audio import downsample_audio, save_audio
1412

1513
cpu = torch.device("cpu")
1614

@@ -210,15 +208,13 @@ def prediction(self, m, vocal_root, others_root, format):
210208
sources = self.demix(mix.T)
211209
opt = sources[0].T
212210
if format in ["wav", "flac"]:
213-
sf.write(
214-
"%s/%s_main_vocal.%s" % (vocal_root, basename, format), mix - opt, rate
215-
)
216-
sf.write("%s/%s_others.%s" % (others_root, basename, format), opt, rate)
211+
save_audio("%s/vocal_%s.%s" % (vocal_root, basename, format), mix - opt, rate)
212+
save_audio("%s/instrument_%s.%s" % (others_root, basename, format), opt, rate)
217213
else:
218-
path_vocal = "%s/%s_main_vocal.wav" % (vocal_root, basename)
219-
path_other = "%s/%s_others.wav" % (others_root, basename)
220-
sf.write(path_vocal, mix - opt, rate)
221-
sf.write(path_other, opt, rate)
214+
path_vocal = "%s/vocal_%s.wav" % (vocal_root, basename)
215+
path_other = "%s/instrument_%s.wav" % (others_root, basename)
216+
save_audio(path_vocal, opt, rate)
217+
save_audio(path_other, opt, rate)
222218
opt_path_vocal = path_vocal[:-4] + ".%s" % format
223219
opt_path_other = path_other[:-4] + ".%s" % format
224220
downsample_audio(path_vocal, opt_path_vocal, format)

infer/modules/uvr5/modules.py

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from configs import Config
1111
from infer.modules.uvr5.mdxnet import MDXNetDereverb
12-
from infer.modules.uvr5.vr import AudioPre, AudioPreDeEcho
12+
from infer.modules.uvr5.vr import AudioPre
1313

1414
config = Config()
1515

@@ -27,8 +27,7 @@ def uvr(model_name, inp_root, save_root_vocal, paths, save_root_ins, agg, format
2727
if model_name == "onnx_dereverb_By_FoxJoy":
2828
pre_fun = MDXNetDereverb(15, config.device)
2929
else:
30-
func = AudioPre if "DeEcho" not in model_name else AudioPreDeEcho
31-
pre_fun = func(
30+
pre_fun = AudioPre(
3231
agg=int(agg),
3332
model_path=os.path.join(
3433
os.getenv("weight_uvr5_root"), model_name + ".pth"
@@ -72,18 +71,10 @@ def uvr(model_name, inp_root, save_root_vocal, paths, save_root_ins, agg, format
7271
infos.append("%s->Success" % (os.path.basename(inp_path)))
7372
yield "\n".join(infos)
7473
except:
75-
try:
76-
if done == 0:
77-
pre_fun._path_audio_(
78-
inp_path, save_root_ins, save_root_vocal, format0
79-
)
80-
infos.append("%s->Success" % (os.path.basename(inp_path)))
81-
yield "\n".join(infos)
82-
except:
83-
infos.append(
84-
"%s->%s" % (os.path.basename(inp_path), traceback.format_exc())
85-
)
86-
yield "\n".join(infos)
74+
infos.append(
75+
"%s->%s" % (os.path.basename(inp_path), traceback.format_exc())
76+
)
77+
yield "\n".join(infos)
8778
except:
8879
infos.append(traceback.format_exc())
8980
yield "\n".join(infos)

0 commit comments

Comments
 (0)