Commit 20fb86a

Add files via upload
1 parent 0fcc293 commit 20fb86a

1 file changed: lib/rmvpe.py (+271 lines, -48 lines)
@@ -1,8 +1,197 @@
-import torch, numpy as np
+import torch, numpy as np, pdb
 import torch.nn as nn
 import torch.nn.functional as F
-
-
+import torch, pdb
+import numpy as np
+import torch.nn.functional as F
+from scipy.signal import get_window
+from librosa.util import pad_center, tiny, normalize
+
+
+### STFT code from https://github.com/pseeth/torch-stft/blob/master/torch_stft/util.py
+def window_sumsquare(window, n_frames, hop_length=200, win_length=800,
+                     n_fft=800, dtype=np.float32, norm=None):
+    """
+    # from librosa 0.6
+    Compute the sum-square envelope of a window function at a given hop length.
+
+    This is used to estimate modulation effects induced by windowing
+    observations in short-time Fourier transforms.
+
+    Parameters
+    ----------
+    window : string, tuple, number, callable, or list-like
+        Window specification, as in `get_window`
+    n_frames : int > 0
+        The number of analysis frames
+    hop_length : int > 0
+        The number of samples to advance between frames
+    win_length : [optional]
+        The length of the window function. By default, this matches `n_fft`.
+    n_fft : int > 0
+        The length of each analysis frame.
+    dtype : np.dtype
+        The data type of the output
+
+    Returns
+    -------
+    wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
+        The sum-squared envelope of the window function
+    """
+    if win_length is None:
+        win_length = n_fft
+
+    n = n_fft + hop_length * (n_frames - 1)
+    x = np.zeros(n, dtype=dtype)
+
+    # Compute the squared window at the desired length
+    win_sq = get_window(window, win_length, fftbins=True)
+    win_sq = normalize(win_sq, norm=norm) ** 2
+    win_sq = pad_center(win_sq, n_fft)
+
+    # Fill the envelope
+    for i in range(n_frames):
+        sample = i * hop_length
+        x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))]
+    return x
+
+
+class STFT(torch.nn.Module):
+    def __init__(self, filter_length=1024, hop_length=512, win_length=None,
+                 window='hann'):
+        """
+        This module implements an STFT using 1D convolution and 1D transpose convolutions.
+        This is a bit tricky, so there are some cases that probably won't work, as working
+        out the same sizes before and after in all overlap-add setups is tough. Right now,
+        this code should work with hop lengths that are half the filter length (50% overlap
+        between frames).
+
+        Keyword Arguments:
+            filter_length {int} -- Length of filters used (default: {1024})
+            hop_length {int} -- Hop length of STFT (restrict to 50% overlap between frames) (default: {512})
+            win_length {[type]} -- Length of the window function applied to each frame (if not specified, it
+                equals the filter length). (default: {None})
+            window {str} -- Type of window to use (options are bartlett, hann, hamming, blackman, blackmanharris)
+                (default: {'hann'})
+        """
+        super(STFT, self).__init__()
+        self.filter_length = filter_length
+        self.hop_length = hop_length
+        self.win_length = win_length if win_length else filter_length
+        self.window = window
+        self.forward_transform = None
+        self.pad_amount = int(self.filter_length / 2)
+        scale = self.filter_length / self.hop_length
+        fourier_basis = np.fft.fft(np.eye(self.filter_length))
+
+        cutoff = int((self.filter_length / 2 + 1))
+        fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]),
+                                   np.imag(fourier_basis[:cutoff, :])])
+        forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
+        inverse_basis = torch.FloatTensor(
+            np.linalg.pinv(scale * fourier_basis).T[:, None, :])
+
+        assert (filter_length >= self.win_length)
+        # get window and zero center pad it to filter_length
+        fft_window = get_window(window, self.win_length, fftbins=True)
+        fft_window = pad_center(fft_window, size=filter_length)
+        fft_window = torch.from_numpy(fft_window).float()
+
+        # window the bases
+        forward_basis *= fft_window
+        inverse_basis *= fft_window
+
+        self.register_buffer('forward_basis', forward_basis.float())
+        self.register_buffer('inverse_basis', inverse_basis.float())
+
+    def transform(self, input_data):
+        """Take input data (audio) to the STFT domain.
+
+        Arguments:
+            input_data {tensor} -- Tensor of floats, with shape (num_batch, num_samples)
+
+        Returns:
+            magnitude {tensor} -- Magnitude of STFT with shape (num_batch,
+                num_frequencies, num_frames)
+            phase {tensor} -- Phase of STFT with shape (num_batch,
+                num_frequencies, num_frames)
+        """
+        num_batches = input_data.shape[0]
+        num_samples = input_data.shape[-1]
+
+        self.num_samples = num_samples
+
+        # similar to librosa, reflect-pad the input
+        input_data = input_data.view(num_batches, 1, num_samples)
+        # print(1234, input_data.shape)
+        input_data = F.pad(input_data.unsqueeze(1),
+                           (self.pad_amount, self.pad_amount, 0, 0, 0, 0),
+                           mode='reflect').squeeze(1)
+        # print(2333, input_data.shape, self.forward_basis.shape, self.hop_length)
+        # pdb.set_trace()
+        forward_transform = F.conv1d(
+            input_data,
+            self.forward_basis,
+            stride=self.hop_length,
+            padding=0)
+
+        cutoff = int((self.filter_length / 2) + 1)
+        real_part = forward_transform[:, :cutoff, :]
+        imag_part = forward_transform[:, cutoff:, :]
+
+        magnitude = torch.sqrt(real_part ** 2 + imag_part ** 2)
+        # phase = torch.atan2(imag_part.data, real_part.data)
+
+        return magnitude  # , phase
+
+    def inverse(self, magnitude, phase):
+        """Call the inverse STFT (iSTFT), given magnitude and phase tensors produced
+        by the ```transform``` function.
+
+        Arguments:
+            magnitude {tensor} -- Magnitude of STFT with shape (num_batch,
+                num_frequencies, num_frames)
+            phase {tensor} -- Phase of STFT with shape (num_batch,
+                num_frequencies, num_frames)
+
+        Returns:
+            inverse_transform {tensor} -- Reconstructed audio given magnitude and phase. Of
+                shape (num_batch, num_samples)
+        """
+        recombine_magnitude_phase = torch.cat(
+            [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1)
+
+        inverse_transform = F.conv_transpose1d(
+            recombine_magnitude_phase,
+            self.inverse_basis,
+            stride=self.hop_length,
+            padding=0)
+
+        if self.window is not None:
+            window_sum = window_sumsquare(
+                self.window, magnitude.size(-1), hop_length=self.hop_length,
+                win_length=self.win_length, n_fft=self.filter_length,
+                dtype=np.float32)
+            # remove modulation effects
+            approx_nonzero_indices = torch.from_numpy(
+                np.where(window_sum > tiny(window_sum))[0])
+            window_sum = torch.from_numpy(window_sum).to(inverse_transform.device)
+            inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices]
+
+            # scale by hop ratio
+            inverse_transform *= float(self.filter_length) / self.hop_length
+
+        inverse_transform = inverse_transform[..., self.pad_amount:]
+        inverse_transform = inverse_transform[..., :self.num_samples]
+        inverse_transform = inverse_transform.squeeze(1)
+
+        return inverse_transform
+
+    def forward(self, input_data):
+        """Take input data (audio) to the STFT domain and then back to audio.
+
+        Arguments:
+            input_data {tensor} -- Tensor of floats, with shape (num_batch, num_samples)
+
+        Returns:
+            reconstruction {tensor} -- Reconstructed audio given magnitude and phase. Of
+                shape (num_batch, num_samples)
+        """
+        self.magnitude, self.phase = self.transform(input_data)
+        reconstruction = self.inverse(self.magnitude, self.phase)
+        return reconstruction
+
+
+from time import time as ttime
 class BiGRU(nn.Module):
     def __init__(self, input_features, hidden_features, num_layers):
         super(BiGRU, self).__init__()
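
For reference, a minimal smoke test of the vendored conv1d STFT (not part of the commit; the 1 s of 16 kHz noise and the hop of 160 are assumptions chosen to mirror RMVPE's mel settings). Note that in this commit transform() returns only the magnitude (the phase line is commented out), so the forward()/inverse() round-trip would need the phase restored before it runs:

# Hypothetical usage sketch, not part of the commit.
import torch

stft = STFT(filter_length=1024, hop_length=160, win_length=1024, window='hann')
audio = torch.randn(1, 16000)      # (num_batch, num_samples): 1 s at 16 kHz (assumed)
magnitude = stft.transform(audio)  # magnitude only; phase is disabled in this commit
print(magnitude.shape)             # torch.Size([1, 513, 101]): (batch, n_fft//2 + 1, frames)
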
@@ -250,9 +439,11 @@ def __init__(
         )

     def forward(self, mel):
+        # print(mel.shape)
         mel = mel.transpose(-1, -2).unsqueeze(1)
         x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2)
         x = self.fc(x)
+        # print(x.shape)
         return x


@@ -301,18 +492,33 @@ def forward(self, audio, keyshift=0, speed=1, center=True):
         keyshift_key = str(keyshift) + "_" + str(audio.device)
         if keyshift_key not in self.hann_window:
             self.hann_window[keyshift_key] = torch.hann_window(win_length_new).to(
+                # "cpu" if (audio.device.type == "privateuseone") else audio.device
                 audio.device
             )
-        fft = torch.stft(
-            audio,
-            n_fft=n_fft_new,
-            hop_length=hop_length_new,
-            win_length=win_length_new,
-            window=self.hann_window[keyshift_key],
-            center=center,
-            return_complex=True,
-        )
-        magnitude = torch.sqrt(fft.real.pow(2) + fft.imag.pow(2))
+        # fft = torch.stft(  # doesn't support pytorch_dml
+        #     # audio.cpu() if (audio.device.type == "privateuseone") else audio,
+        #     audio,
+        #     n_fft=n_fft_new,
+        #     hop_length=hop_length_new,
+        #     win_length=win_length_new,
+        #     window=self.hann_window[keyshift_key],
+        #     center=center,
+        #     return_complex=True,
+        # )
+        # magnitude = torch.sqrt(fft.real.pow(2) + fft.imag.pow(2))
+        # print(1111111111)
+        # print(222222222222222, audio.device, self.is_half)
+        if hasattr(self, "stft") == False:
+            # print(n_fft_new, hop_length_new, win_length_new, audio.shape)
+            self.stft = STFT(
+                filter_length=n_fft_new,
+                hop_length=hop_length_new,
+                win_length=win_length_new,
+                window='hann'
+            ).to(audio.device)
+        magnitude = self.stft.transform(audio)  # phase
+        # if (audio.device.type == "privateuseone"):
+        #     magnitude = magnitude.to(audio.device)
         if keyshift != 0:
             size = self.n_fft // 2 + 1
             resize = magnitude.size(1)
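
The hunk above swaps torch.stft for the vendored conv1d STFT so the mel path also runs on backends without torch.stft support (e.g. torch-directml's "privateuseone" devices). A sketch of how the swap might be sanity-checked on CPU (assumptions: periodic Hann window and reflect centering on both paths; not part of the commit):

# Hypothetical equivalence check, not part of the commit.
import torch

n_fft, hop = 1024, 160
audio = torch.randn(1, 16000)
ref = torch.stft(audio, n_fft=n_fft, hop_length=hop, win_length=n_fft,
                 window=torch.hann_window(n_fft), center=True,
                 return_complex=True).abs()
conv = STFT(filter_length=n_fft, hop_length=hop, win_length=n_fft,
            window='hann').transform(audio)
print(torch.allclose(ref, conv, atol=1e-3))  # expected: True, up to float32 error
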
@@ -323,19 +529,13 @@ def forward(self, audio, keyshift=0, speed=1, center=True):
         if self.is_half == True:
             mel_output = mel_output.half()
         log_mel_spec = torch.log(torch.clamp(mel_output, min=self.clamp))
+        # print(log_mel_spec.device.type)
         return log_mel_spec


 class RMVPE:
     def __init__(self, model_path, is_half, device=None):
         self.resample_kernel = {}
-        model = E2E(4, 1, (2, 2))
-        ckpt = torch.load(model_path, map_location="cpu")
-        model.load_state_dict(ckpt)
-        model.eval()
-        if is_half == True:
-            model = model.half()
-        self.model = model
         self.resample_kernel = {}
         self.is_half = is_half
         if device is None:
@@ -344,7 +544,19 @@ def __init__(self, model_path, is_half, device=None):
         self.mel_extractor = MelSpectrogram(
             is_half, 128, 16000, 1024, 160, None, 30, 8000
         ).to(device)
-        self.model = self.model.to(device)
+        if ("privateuseone" in str(device)):
+            import onnxruntime as ort
+            ort_session = ort.InferenceSession("rmvpe.onnx", providers=["DmlExecutionProvider"])
+            self.model = ort_session
+        else:
+            model = E2E(4, 1, (2, 2))
+            ckpt = torch.load(model_path, map_location="cpu")
+            model.load_state_dict(ckpt)
+            model.eval()
+            if is_half == True:
+                model = model.half()
+            self.model = model
+            self.model = self.model.to(device)
         cents_mapping = 20 * np.arange(360) + 1997.3794084376191
         self.cents_mapping = np.pad(cents_mapping, (4, 4))  # 368

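The DirectML branch above loads a pre-built rmvpe.onnx from the working directory; the commit does not show how that file is produced. One plausible export, sketched under assumptions (the input/output names, the dynamic frame axis, and the opset are guesses, not taken from the repo):

# Hypothetical ONNX export, not part of the commit.
import torch

model = E2E(4, 1, (2, 2))
model.load_state_dict(torch.load("rmvpe.pt", map_location="cpu"))  # checkpoint path assumed
model.eval()
dummy_mel = torch.randn(1, 128, 1024)  # (batch, n_mels, n_frames); frames padded to a multiple of 32
torch.onnx.export(model, dummy_mel, "rmvpe.onnx",
                  input_names=["mel"], output_names=["hidden"],
                  dynamic_axes={"mel": {2: "n_frames"}, "hidden": {1: "n_frames"}},
                  opset_version=17)
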
@@ -354,7 +566,12 @@ def mel2hidden(self, mel):
         mel = F.pad(
             mel, (0, 32 * ((n_frames - 1) // 32 + 1) - n_frames), mode="reflect"
         )
-        hidden = self.model(mel)
+        if ("privateuseone" in str(self.device)):
+            onnx_input_name = self.model.get_inputs()[0].name
+            onnx_outputs_names = self.model.get_outputs()[0].name
+            hidden = self.model.run([onnx_outputs_names], input_feed={onnx_input_name: mel.cpu().numpy()})[0]
+        else:
+            hidden = self.model(mel)
         return hidden[:, :n_frames]

     def decode(self, hidden, thred=0.03):
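
In mel2hidden above, the reflect pad rounds the frame count up to the next multiple of 32 (presumably the model's total downsampling factor), and the output is cropped back to n_frames afterward. A worked instance of that arithmetic:

# Illustrative arithmetic only, not part of the commit.
n_frames = 101
padded = 32 * ((n_frames - 1) // 32 + 1)  # -> 128, the next multiple of 32
print(padded, padded - n_frames)          # 128 27: 27 frames of reflect padding appended
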
@@ -365,21 +582,26 @@ def decode(self, hidden, thred=0.03):
         return f0

     def infer_from_audio(self, audio, thred=0.03):
-        audio = torch.from_numpy(audio).float().to(self.device).unsqueeze(0)
         # torch.cuda.synchronize()
-        # t0 = ttime()
-        mel = self.mel_extractor(audio, center=True)
+        t0 = ttime()
+        mel = self.mel_extractor(torch.from_numpy(audio).float().to(self.device).unsqueeze(0), center=True)
+        # print(123123123, mel.device.type)
         # torch.cuda.synchronize()
-        # t1 = ttime()
+        t1 = ttime()
         hidden = self.mel2hidden(mel)
         # torch.cuda.synchronize()
-        # t2 = ttime()
-        hidden = hidden.squeeze(0).cpu().numpy()
+        t2 = ttime()
+        # print(234234, hidden.device.type)
+        if ("privateuseone" not in str(self.device)):
+            hidden = hidden.squeeze(0).cpu().numpy()
+        else:
+            hidden = hidden[0]
         if self.is_half == True:
             hidden = hidden.astype("float32")
+
         f0 = self.decode(hidden, thred=thred)
         # torch.cuda.synchronize()
-        # t3 = ttime()
+        t3 = ttime()
         # print("hmvpe:%s\t%s\t%s\t%s" % (t1 - t0, t2 - t1, t3 - t2, t3 - t0))
         return f0

@@ -410,22 +632,23 @@ def to_local_average_cents(self, salience, thred=0.05):
         return devided


-# if __name__ == '__main__':
-#     audio, sampling_rate = sf.read("卢本伟语录~1.wav")
-#     if len(audio.shape) > 1:
-#         audio = librosa.to_mono(audio.transpose(1, 0))
-#     audio_bak = audio.copy()
-#     if sampling_rate != 16000:
-#         audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=16000)
-#     model_path = "/bili-coeus/jupyter/jupyterhub-liujing04/vits_ch/test-RMVPE/weights/rmvpe_llc_half.pt"
-#     thred = 0.03  # 0.01
-#     device = 'cuda' if torch.cuda.is_available() else 'cpu'
-#     rmvpe = RMVPE(model_path, is_half=False, device=device)
-#     t0 = ttime()
-#     f0 = rmvpe.infer_from_audio(audio, thred=thred)
-#     f0 = rmvpe.infer_from_audio(audio, thred=thred)
-#     f0 = rmvpe.infer_from_audio(audio, thred=thred)
-#     f0 = rmvpe.infer_from_audio(audio, thred=thred)
-#     f0 = rmvpe.infer_from_audio(audio, thred=thred)
-#     t1 = ttime()
-#     print(f0.shape, t1 - t0)
+if __name__ == '__main__':
+    import soundfile as sf, librosa
+    audio, sampling_rate = sf.read(r"C:\Users\liujing04\Desktop\Z\冬之花clip1.wav")
+    if len(audio.shape) > 1:
+        audio = librosa.to_mono(audio.transpose(1, 0))
+    audio_bak = audio.copy()
+    if sampling_rate != 16000:
+        audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=16000)
+    model_path = r"D:\BaiduNetdiskDownload\RVC-beta-v2-0727AMD_realtime\rmvpe.pt"
+    thred = 0.03  # 0.01
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    rmvpe = RMVPE(model_path, is_half=False, device=device)
+    t0 = ttime()
+    f0 = rmvpe.infer_from_audio(audio, thred=thred)
+    # f0 = rmvpe.infer_from_audio(audio, thred=thred)
+    # f0 = rmvpe.infer_from_audio(audio, thred=thred)
+    # f0 = rmvpe.infer_from_audio(audio, thred=thred)
+    # f0 = rmvpe.infer_from_audio(audio, thred=thred)
+    t1 = ttime()
+    print(f0.shape, t1 - t0)
