Skip to content

Commit 5969314

Browse files
committed
optimize(uvr5): apply jit to spec_utils & fix flac save
also fix #85
1 parent 4582d4b commit 5969314

File tree

11 files changed

+104
-581
lines changed

11 files changed

+104
-581
lines changed

infer/lib/audio.py

Lines changed: 11 additions & 45 deletions
Original file line number | Diff line number | Diff line change
@@ -43,9 +43,14 @@ def float_np_array_to_wav_buf(wav: np.ndarray, sr: int, f32=False) -> BytesIO:
4343
return buf
4444

4545

46-
def save_audio(path: str, audio: np.ndarray, sr: int, f32=False):
46+
def save_audio(path: str, audio: np.ndarray, sr: int, f32=False, format="wav"):
47+
buf = float_np_array_to_wav_buf(audio, sr, f32)
48+
if format != "wav":
49+
transbuf = BytesIO()
50+
wav2(buf, transbuf, format)
51+
buf = transbuf
4752
with open(path, "wb") as f:
48-
f.write(float_np_array_to_wav_buf(audio, sr, f32).getbuffer())
53+
f.write(buf.getbuffer())
4954

5055

5156
def wav2(i: BytesIO, o: BufferedWriter, format: str):
@@ -109,7 +114,7 @@ def process_packet(packet: List[AudioFrame]):
109114
frames_data = []
110115
rate = 0
111116
for frame in packet:
112-
frame.pts = None # 清除时间戳,避免重新采样问题
117+
# frame.pts = None # 清除时间戳,避免重新采样问题
113118
resampled_frames = (
114119
resampler.resample(frame) if resampler is not None else [frame]
115120
)
@@ -137,6 +142,8 @@ def frame_iter(container):
137142

138143
np.copyto(decoded_audio[..., offset:end_index], frame_data)
139144
offset += len(frame_data[0])
145+
146+
container.close()
140147

141148
# Truncate the array to the actual size
142149
decoded_audio = decoded_audio[..., :offset]
@@ -149,43 +156,6 @@ def frame_iter(container):
149156
return decoded_audio, rate
150157

151158

152-
def downsample_audio(
153-
input_path: str, output_path: str, format: str, br=128_000
154-
) -> None:
155-
"""
156-
default to 128kb/s (equivalent to -q:a 2)
157-
"""
158-
if not os.path.exists(input_path):
159-
return
160-
161-
input_container = av.open(input_path)
162-
output_container = av.open(output_path, "w")
163-
164-
# Create a stream in the output container
165-
input_stream = input_container.streams.audio[0]
166-
output_stream = output_container.add_stream(format)
167-
168-
output_stream.bit_rate = br
169-
170-
# Copy packets from the input file to the output file
171-
for packet in input_container.demux(input_stream):
172-
for frame in packet.decode():
173-
for out_packet in output_stream.encode(frame):
174-
output_container.mux(out_packet)
175-
176-
for packet in output_stream.encode():
177-
output_container.mux(packet)
178-
179-
# Close the containers
180-
input_container.close()
181-
output_container.close()
182-
183-
try: # Remove the original file
184-
os.remove(input_path)
185-
except Exception as e:
186-
print(f"Failed to remove the original file: {e}")
187-
188-
189159
def resample_audio(
190160
input_path: str, output_path: str, codec: str, format: str, sr: int, layout: str
191161
) -> None:
@@ -204,7 +174,7 @@ def resample_audio(
204174
# Copy packets from the input file to the output file
205175
for packet in input_container.demux(input_stream):
206176
for frame in packet.decode():
207-
frame.pts = None # Clear presentation timestamp to avoid resampling issues
177+
# frame.pts = None # Clear presentation timestamp to avoid resampling issues
208178
out_frames = resampler.resample(frame)
209179
for out_frame in out_frames:
210180
for out_packet in output_stream.encode(out_frame):
@@ -217,10 +187,6 @@ def resample_audio(
217187
input_container.close()
218188
output_container.close()
219189

220-
try: # Remove the original file
221-
os.remove(input_path)
222-
except Exception as e:
223-
print(f"Failed to remove the original file: {e}")
224190

225191

226192
def get_audio_properties(input_path: str) -> Tuple[int, int]:

infer/lib/train/utils.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@
55
import os
66
import sys
77
from copy import deepcopy
8+
import math
89

910
import codecs
1011
import numpy as np
@@ -103,7 +104,7 @@ def summarize(
103104

104105
def latest_checkpoint_path(dir_path, regex="G_*.pth"):
105106
f_list = glob.glob(os.path.join(dir_path, regex))
106-
f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
107+
f_list.sort(key=lambda f: 999999999999 if isinstance(f, str) and f == "latest" else int("0"+"".join(filter(str.isdigit, f))))
107108
x = f_list[-1]
108109
logger.debug(x)
109110
return x

infer/lib/uvr5_pack/lib_v5/dataset.py

Lines changed: 0 additions & 183 deletions
This file was deleted.

infer/lib/uvr5_pack/lib_v5/layers.py

Lines changed: 8 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,8 @@ def __init__(self, nin, nout, ksize=3, stride=1, pad=1, dilation=1, activ=nn.ReL
2222
activ(),
2323
)
2424

25-
def __call__(self, x):
25+
@torch.inference_mode()
26+
def forward(self, x):
2627
return self.conv(x)
2728

2829

@@ -32,7 +33,8 @@ def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.LeakyReLU):
3233
self.conv1 = Conv2DBNActiv(nin, nout, ksize, stride, pad, activ=activ)
3334
self.conv2 = Conv2DBNActiv(nout, nout, ksize, 1, pad, activ=activ)
3435

35-
def __call__(self, x):
36+
@torch.inference_mode()
37+
def forward(self, x):
3638
h = self.conv1(x)
3739
h = self.conv2(h)
3840

@@ -48,7 +50,8 @@ def __init__(
4850
# self.conv2 = Conv2DBNActiv(nout, nout, ksize, 1, pad, activ=activ)
4951
self.dropout = nn.Dropout2d(0.1) if dropout else None
5052

51-
def __call__(self, x, skip=None):
53+
@torch.inference_mode()
54+
def forward(self, x, skip=None):
5255
x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True)
5356

5457
if skip is not None:
@@ -84,6 +87,7 @@ def __init__(self, nin, nout, dilations=(4, 8, 12), activ=nn.ReLU, dropout=False
8487
self.bottleneck = Conv2DBNActiv(nout * 5, nout, 1, 1, 0, activ=activ)
8588
self.dropout = nn.Dropout2d(0.1) if dropout else None
8689

90+
@torch.inference_mode()
8791
def forward(self, x):
8892
_, _, h, w = x.size()
8993
feat1 = F.interpolate(
@@ -113,6 +117,7 @@ def __init__(self, nin_conv, nin_lstm, nout_lstm):
113117
nn.Linear(nout_lstm, nin_lstm), nn.BatchNorm1d(nin_lstm), nn.ReLU()
114118
)
115119

120+
@torch.inference_mode()
116121
def forward(self, x):
117122
N, _, nbins, nframes = x.size()
118123
h = self.conv(x)[:, 0] # N, nbins, nframes

infer/lib/uvr5_pack/lib_v5/nets.py

Lines changed: 3 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,8 @@ def __init__(
2424
self.lstm_dec2 = layers.LSTMModule(nout * 2, nin_lstm, nout_lstm)
2525
self.dec1 = layers.Decoder(nout * (1 + 2) + 1, nout * 1, 3, 1, 1)
2626

27-
def __call__(self, x):
27+
@torch.inference_mode()
28+
def forward(self, x):
2829
e1 = self.enc1(x)
2930
e2 = self.enc2(e1)
3031
e3 = self.enc3(e2)
@@ -75,6 +76,7 @@ def __init__(self, n_fft, nout=32, nout_lstm=128):
7576
self.out = nn.Conv2d(nout, 2, 1, bias=False)
7677
self.aux_out = nn.Conv2d(3 * nout // 4, 2, 1, bias=False)
7778

79+
@torch.inference_mode()
7880
def forward(self, x):
7981
x = x[:, :, : self.max_bin]
8082

@@ -112,22 +114,3 @@ def forward(self, x):
112114
return mask, aux
113115
else:
114116
return mask
115-
116-
def predict_mask(self, x):
117-
mask = self.forward(x)
118-
119-
if self.offset > 0:
120-
mask = mask[:, :, :, self.offset : -self.offset]
121-
assert mask.size()[3] > 0
122-
123-
return mask
124-
125-
def predict(self, x, aggressiveness=None):
126-
mask = self.forward(x)
127-
pred_mag = x * mask
128-
129-
if self.offset > 0:
130-
pred_mag = pred_mag[:, :, :, self.offset : -self.offset]
131-
assert pred_mag.size()[3] > 0
132-
133-
return pred_mag

0 commit comments

Comments
 (0)