Commit e486649

optimize(rmvpe): move deepunet&e2e into rvc
1 parent 1e22d46 commit e486649

File tree

4 files changed: +289 −287 lines changed


infer/lib/audio.py

Lines changed: 5 additions & 6 deletions
@@ -1,9 +1,10 @@
 from io import BufferedWriter, BytesIO
 from pathlib import Path
 from typing import Dict, Tuple
+import os
+
 import numpy as np
 import av
-import os
 from av.audio.resampler import AudioResampler
 
 video_format_dict: Dict[str, str] = {
@@ -44,18 +45,16 @@ def load_audio(file: str, sr: int) -> np.ndarray:
     resampler = AudioResampler(format="fltp", layout="mono", rate=sr)
 
     # Estimated maximum total number of samples to pre-allocate the array
-    audio_duration_sec: float = (
-        container.duration / 1_000_000
-    )  # AV stores length in microseconds by default
-    estimated_total_samples = int(audio_duration_sec * sr + 0.5)
+    # AV stores length in microseconds by default
+    estimated_total_samples = int(container.duration * sr // 1_000_000)
     decoded_audio = np.zeros(estimated_total_samples + 1, dtype=np.float32)
 
     offset = 0
     for frame in container.decode(audio=0):
         frame.pts = None  # Clear presentation timestamp to avoid resampling issues
         resampled_frames = resampler.resample(frame)
         for resampled_frame in resampled_frames:
-            frame_data = np.array(resampled_frame.to_ndarray()).flatten()
+            frame_data = resampled_frame.to_ndarray()[0]
             end_index = offset + len(frame_data)
 
             # Check if decoded_audio has enough space, and resize if necessary
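To make the two behavioral changes above concrete, here is a minimal, self-contained sketch of the revised decode path. The helper name decode_mono and the resize guard are illustrative stand-ins for the elided load_audio body; it assumes PyAV's container.duration is expressed in microseconds and that a mono-layout resampled frame's to_ndarray() has shape (1, n_samples).

import av
import numpy as np
from av.audio.resampler import AudioResampler

def decode_mono(file: str, sr: int) -> np.ndarray:
    # Hypothetical helper sketching this commit's decode loop, not the repo's API.
    container = av.open(file)
    resampler = AudioResampler(format="fltp", layout="mono", rate=sr)
    # Pure integer arithmetic: floor division, no float divide then round.
    estimated_total_samples = int(container.duration * sr // 1_000_000)
    decoded_audio = np.zeros(estimated_total_samples + 1, dtype=np.float32)
    offset = 0
    for frame in container.decode(audio=0):
        frame.pts = None  # clear stale timestamps before resampling
        for resampled_frame in resampler.resample(frame):
            # Row 0 of the (1, n) mono array, instead of copy-and-flatten.
            frame_data = resampled_frame.to_ndarray()[0]
            end_index = offset + len(frame_data)
            if end_index > len(decoded_audio):  # grow if the estimate fell short
                decoded_audio = np.resize(decoded_audio, end_index)
            decoded_audio[offset:end_index] = frame_data
            offset = end_index
    return decoded_audio[:offset]

Floor division keeps the sample estimate conservative and purely integral; the + 1 slack plus the resize check absorb any shortfall, and indexing row 0 avoids the per-frame copy that np.array(...).flatten() used to make.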

infer/lib/rmvpe.py

Lines changed: 1 addition & 281 deletions
@@ -18,269 +18,13 @@
     pass
 import torch.nn as nn
 import torch.nn.functional as F
-from librosa.util import normalize, pad_center, tiny
-from scipy.signal import get_window
 
 import logging
 
 logger = logging.getLogger(__name__)
 
 from rvc.f0.mel import MelSpectrogram
-
-from time import time as ttime
-
-
-class BiGRU(nn.Module):
-    def __init__(self, input_features, hidden_features, num_layers):
-        super(BiGRU, self).__init__()
-        self.gru = nn.GRU(
-            input_features,
-            hidden_features,
-            num_layers=num_layers,
-            batch_first=True,
-            bidirectional=True,
-        )
-
-    def forward(self, x):
-        return self.gru(x)[0]
-
-
-class ConvBlockRes(nn.Module):
-    def __init__(self, in_channels, out_channels, momentum=0.01):
-        super(ConvBlockRes, self).__init__()
-        self.conv = nn.Sequential(
-            nn.Conv2d(
-                in_channels=in_channels,
-                out_channels=out_channels,
-                kernel_size=(3, 3),
-                stride=(1, 1),
-                padding=(1, 1),
-                bias=False,
-            ),
-            nn.BatchNorm2d(out_channels, momentum=momentum),
-            nn.ReLU(),
-            nn.Conv2d(
-                in_channels=out_channels,
-                out_channels=out_channels,
-                kernel_size=(3, 3),
-                stride=(1, 1),
-                padding=(1, 1),
-                bias=False,
-            ),
-            nn.BatchNorm2d(out_channels, momentum=momentum),
-            nn.ReLU(),
-        )
-        # self.shortcut:Optional[nn.Module] = None
-        if in_channels != out_channels:
-            self.shortcut = nn.Conv2d(in_channels, out_channels, (1, 1))
-
-    def forward(self, x: torch.Tensor):
-        if not hasattr(self, "shortcut"):
-            return self.conv(x) + x
-        else:
-            return self.conv(x) + self.shortcut(x)
-
-
-class Encoder(nn.Module):
-    def __init__(
-        self,
-        in_channels,
-        in_size,
-        n_encoders,
-        kernel_size,
-        n_blocks,
-        out_channels=16,
-        momentum=0.01,
-    ):
-        super(Encoder, self).__init__()
-        self.n_encoders = n_encoders
-        self.bn = nn.BatchNorm2d(in_channels, momentum=momentum)
-        self.layers = nn.ModuleList()
-        self.latent_channels = []
-        for i in range(self.n_encoders):
-            self.layers.append(
-                ResEncoderBlock(
-                    in_channels, out_channels, kernel_size, n_blocks, momentum=momentum
-                )
-            )
-            self.latent_channels.append([out_channels, in_size])
-            in_channels = out_channels
-            out_channels *= 2
-            in_size //= 2
-        self.out_size = in_size
-        self.out_channel = out_channels
-
-    def forward(self, x: torch.Tensor):
-        concat_tensors: List[torch.Tensor] = []
-        x = self.bn(x)
-        for i, layer in enumerate(self.layers):
-            t, x = layer(x)
-            concat_tensors.append(t)
-        return x, concat_tensors
-
-
-class ResEncoderBlock(nn.Module):
-    def __init__(
-        self, in_channels, out_channels, kernel_size, n_blocks=1, momentum=0.01
-    ):
-        super(ResEncoderBlock, self).__init__()
-        self.n_blocks = n_blocks
-        self.conv = nn.ModuleList()
-        self.conv.append(ConvBlockRes(in_channels, out_channels, momentum))
-        for i in range(n_blocks - 1):
-            self.conv.append(ConvBlockRes(out_channels, out_channels, momentum))
-        self.kernel_size = kernel_size
-        if self.kernel_size is not None:
-            self.pool = nn.AvgPool2d(kernel_size=kernel_size)
-
-    def forward(self, x):
-        for i, conv in enumerate(self.conv):
-            x = conv(x)
-        if self.kernel_size is not None:
-            return x, self.pool(x)
-        else:
-            return x
-
-
-class Intermediate(nn.Module):  #
-    def __init__(self, in_channels, out_channels, n_inters, n_blocks, momentum=0.01):
-        super(Intermediate, self).__init__()
-        self.n_inters = n_inters
-        self.layers = nn.ModuleList()
-        self.layers.append(
-            ResEncoderBlock(in_channels, out_channels, None, n_blocks, momentum)
-        )
-        for i in range(self.n_inters - 1):
-            self.layers.append(
-                ResEncoderBlock(out_channels, out_channels, None, n_blocks, momentum)
-            )
-
-    def forward(self, x):
-        for i, layer in enumerate(self.layers):
-            x = layer(x)
-        return x
-
-
-class ResDecoderBlock(nn.Module):
-    def __init__(self, in_channels, out_channels, stride, n_blocks=1, momentum=0.01):
-        super(ResDecoderBlock, self).__init__()
-        out_padding = (0, 1) if stride == (1, 2) else (1, 1)
-        self.n_blocks = n_blocks
-        self.conv1 = nn.Sequential(
-            nn.ConvTranspose2d(
-                in_channels=in_channels,
-                out_channels=out_channels,
-                kernel_size=(3, 3),
-                stride=stride,
-                padding=(1, 1),
-                output_padding=out_padding,
-                bias=False,
-            ),
-            nn.BatchNorm2d(out_channels, momentum=momentum),
-            nn.ReLU(),
-        )
-        self.conv2 = nn.ModuleList()
-        self.conv2.append(ConvBlockRes(out_channels * 2, out_channels, momentum))
-        for i in range(n_blocks - 1):
-            self.conv2.append(ConvBlockRes(out_channels, out_channels, momentum))
-
-    def forward(self, x, concat_tensor):
-        x = self.conv1(x)
-        x = torch.cat((x, concat_tensor), dim=1)
-        for i, conv2 in enumerate(self.conv2):
-            x = conv2(x)
-        return x
-
-
-class Decoder(nn.Module):
-    def __init__(self, in_channels, n_decoders, stride, n_blocks, momentum=0.01):
-        super(Decoder, self).__init__()
-        self.layers = nn.ModuleList()
-        self.n_decoders = n_decoders
-        for i in range(self.n_decoders):
-            out_channels = in_channels // 2
-            self.layers.append(
-                ResDecoderBlock(in_channels, out_channels, stride, n_blocks, momentum)
-            )
-            in_channels = out_channels
-
-    def forward(self, x: torch.Tensor, concat_tensors: List[torch.Tensor]):
-        for i, layer in enumerate(self.layers):
-            x = layer(x, concat_tensors[-1 - i])
-        return x
-
-
-class DeepUnet(nn.Module):
-    def __init__(
-        self,
-        kernel_size,
-        n_blocks,
-        en_de_layers=5,
-        inter_layers=4,
-        in_channels=1,
-        en_out_channels=16,
-    ):
-        super(DeepUnet, self).__init__()
-        self.encoder = Encoder(
-            in_channels, 128, en_de_layers, kernel_size, n_blocks, en_out_channels
-        )
-        self.intermediate = Intermediate(
-            self.encoder.out_channel // 2,
-            self.encoder.out_channel,
-            inter_layers,
-            n_blocks,
-        )
-        self.decoder = Decoder(
-            self.encoder.out_channel, en_de_layers, kernel_size, n_blocks
-        )
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x, concat_tensors = self.encoder(x)
-        x = self.intermediate(x)
-        x = self.decoder(x, concat_tensors)
-        return x
-
-
-class E2E(nn.Module):
-    def __init__(
-        self,
-        n_blocks,
-        n_gru,
-        kernel_size,
-        en_de_layers=5,
-        inter_layers=4,
-        in_channels=1,
-        en_out_channels=16,
-    ):
-        super(E2E, self).__init__()
-        self.unet = DeepUnet(
-            kernel_size,
-            n_blocks,
-            en_de_layers,
-            inter_layers,
-            in_channels,
-            en_out_channels,
-        )
-        self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1))
-        if n_gru:
-            self.fc = nn.Sequential(
-                BiGRU(3 * 128, 256, n_gru),
-                nn.Linear(512, 360),
-                nn.Dropout(0.25),
-                nn.Sigmoid(),
-            )
-        else:
-            self.fc = nn.Sequential(
-                nn.Linear(3 * nn.N_MELS, nn.N_CLASS), nn.Dropout(0.25), nn.Sigmoid()
-            )
-
-    def forward(self, mel):
-        # print(mel.shape)
-        mel = mel.transpose(-1, -2).unsqueeze(1)
-        x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2)
-        x = self.fc(x)
-        # print(x.shape)
-        return x
+from rvc.f0.e2e import E2E
 
 
 class RMVPE:
@@ -442,27 +186,3 @@ def to_local_average_cents(self, salience, thred=0.05):
         # t4 = ttime()
         # print("decode:%s\t%s\t%s\t%s" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
         return devided
-
-
-if __name__ == "__main__":
-    import librosa
-    import soundfile as sf
-
-    audio, sampling_rate = sf.read(r"C:\Users\liujing04\Desktop\Z\冬之花clip1.wav")
-    if len(audio.shape) > 1:
-        audio = librosa.to_mono(audio.transpose(1, 0))
-    audio_bak = audio.copy()
-    if sampling_rate != 16000:
-        audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=16000)
-    model_path = r"D:\BaiduNetdiskDownload\RVC-beta-v2-0727AMD_realtime\rmvpe.pt"
-    thred = 0.03  # 0.01
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-    rmvpe = RMVPE(model_path, is_half=False, device=device)
-    t0 = ttime()
-    f0 = rmvpe.infer_from_audio(audio, thred=thred)
-    # f0 = rmvpe.infer_from_audio(audio, thred=thred)
-    # f0 = rmvpe.infer_from_audio(audio, thred=thred)
-    # f0 = rmvpe.infer_from_audio(audio, thred=thred)
-    # f0 = rmvpe.infer_from_audio(audio, thred=thred)
-    t1 = ttime()
-    logger.info("%s %.2f", f0.shape, t1 - t0)
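With the network definitions relocated, rmvpe.py consumes E2E as an ordinary import. Below is a hedged sketch of the post-move usage; the constructor arguments and checkpoint path are illustrative (the arguments mirror how RMVPE typically builds the model in this repository) and are not taken from this diff.

# Illustrative only: assumes rvc.f0.e2e exports the same E2E class that was
# previously defined inline in infer/lib/rmvpe.py.
import torch
from rvc.f0.e2e import E2E

model = E2E(4, 1, (2, 2))  # n_blocks, n_gru, kernel_size (typical RMVPE config)
ckpt = torch.load("assets/rmvpe/rmvpe.pt", map_location="cpu")  # hypothetical path
model.load_state_dict(ckpt)
model.eval()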
