-import torch, numpy as np
+import torch, numpy as np, pdb
 import torch.nn as nn
 import torch.nn.functional as F
-
-
+from scipy.signal import get_window
+from librosa.util import pad_center, tiny, normalize
+### STFT code adapted from https://github.com/pseeth/torch-stft/blob/master/torch_stft/util.py
+def window_sumsquare(window, n_frames, hop_length=200, win_length=800,
+                     n_fft=800, dtype=np.float32, norm=None):
+    """
+    # from librosa 0.6
+    Compute the sum-square envelope of a window function at a given hop length.
+    This is used to estimate modulation effects induced by windowing
+    observations in short-time Fourier transforms.
+    Parameters
+    ----------
+    window : string, tuple, number, callable, or list-like
+        Window specification, as in `get_window`
+    n_frames : int > 0
+        The number of analysis frames
+    hop_length : int > 0
+        The number of samples to advance between frames
+    win_length : [optional]
+        The length of the window function. By default, this matches `n_fft`.
+    n_fft : int > 0
+        The length of each analysis frame.
+    dtype : np.dtype
+        The data type of the output
+    Returns
+    -------
+    wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
+        The sum-squared envelope of the window function
+    """
+    if win_length is None:
+        win_length = n_fft
+
+    n = n_fft + hop_length * (n_frames - 1)
+    x = np.zeros(n, dtype=dtype)
+
+    # Compute the squared window at the desired length
+    win_sq = get_window(window, win_length, fftbins=True)
+    win_sq = normalize(win_sq, norm=norm) ** 2
+    win_sq = pad_center(win_sq, size=n_fft)
+
+    # Fill the envelope by overlap-adding the squared window at each hop
+    for i in range(n_frames):
+        sample = i * hop_length
+        x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))]
+    return x
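+# Editor's note (not in the original commit): a minimal sanity check of the
+# envelope shape documented above:
+#
+#     env = window_sumsquare("hann", n_frames=10, hop_length=512,
+#                            win_length=1024, n_fft=1024)
+#     assert env.shape == (1024 + 512 * (10 - 1),)
+#
+# STFT.inverse() below divides by this envelope wherever it exceeds tiny(env)
+# to undo the per-sample gain introduced by windowing.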
+
+class STFT(torch.nn.Module):
+    def __init__(self, filter_length=1024, hop_length=512, win_length=None,
+                 window='hann'):
+        """
+        This module implements an STFT using 1D convolution and 1D transpose convolutions.
+        This is a bit tricky, so some configurations probably won't work, because matching
+        sizes before and after the transform in every overlap-add setup is hard. Right now,
+        this code should work with hop lengths that are half the filter length (50% overlap
+        between frames).
+
+        Keyword Arguments:
+            filter_length {int} -- Length of filters used (default: {1024})
+            hop_length {int} -- Hop length of STFT (restrict to 50% overlap between frames) (default: {512})
+            win_length {[type]} -- Length of the window function applied to each frame (if not specified, it
+                equals the filter length). (default: {None})
+            window {str} -- Type of window to use (options are bartlett, hann, hamming, blackman, blackmanharris)
+                (default: {'hann'})
+        """
+        super(STFT, self).__init__()
+        self.filter_length = filter_length
+        self.hop_length = hop_length
+        self.win_length = win_length if win_length else filter_length
+        self.window = window
+        self.forward_transform = None
+        self.pad_amount = int(self.filter_length / 2)
+        scale = self.filter_length / self.hop_length
+        fourier_basis = np.fft.fft(np.eye(self.filter_length))
+
+        cutoff = int(self.filter_length / 2 + 1)
+        fourier_basis = np.vstack(
+            [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
+        )
+        forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
+        inverse_basis = torch.FloatTensor(
+            np.linalg.pinv(scale * fourier_basis).T[:, None, :]
+        )
+
+        assert filter_length >= self.win_length
+        # get window and zero center pad it to filter_length
+        fft_window = get_window(window, self.win_length, fftbins=True)
+        fft_window = pad_center(fft_window, size=filter_length)
+        fft_window = torch.from_numpy(fft_window).float()
+
+        # window the bases
+        forward_basis *= fft_window
+        inverse_basis *= fft_window
+
+        self.register_buffer('forward_basis', forward_basis.float())
+        self.register_buffer('inverse_basis', inverse_basis.float())
+
+    def transform(self, input_data):
+        """Take input data (audio) to the STFT domain.
+
+        Arguments:
+            input_data {tensor} -- Tensor of floats, with shape (num_batch, num_samples)
+
+        Returns:
+            magnitude {tensor} -- Magnitude of STFT with shape (num_batch,
+                num_frequencies, num_frames). (Phase computation is commented
+                out in this commit, so only the magnitude is returned.)
+        """
+        num_batches = input_data.shape[0]
+        num_samples = input_data.shape[-1]
+
+        self.num_samples = num_samples
+
+        # similar to librosa, reflect-pad the input
+        input_data = input_data.view(num_batches, 1, num_samples)
+        input_data = F.pad(
+            input_data.unsqueeze(1),
+            (self.pad_amount, self.pad_amount, 0, 0, 0, 0),
+            mode='reflect',
+        ).squeeze(1)
+        forward_transform = F.conv1d(
+            input_data,
+            self.forward_basis,
+            stride=self.hop_length,
+            padding=0,
+        )
+
+        cutoff = int((self.filter_length / 2) + 1)
+        real_part = forward_transform[:, :cutoff, :]
+        imag_part = forward_transform[:, cutoff:, :]
+
+        magnitude = torch.sqrt(real_part ** 2 + imag_part ** 2)
+        # phase = torch.atan2(imag_part.data, real_part.data)
+
+        return magnitude  # , phase
+
+    def inverse(self, magnitude, phase):
+        """Call the inverse STFT (iSTFT), given magnitude and phase tensors produced
+        by the ``transform`` function.
+
+        Arguments:
+            magnitude {tensor} -- Magnitude of STFT with shape (num_batch,
+                num_frequencies, num_frames)
+            phase {tensor} -- Phase of STFT with shape (num_batch,
+                num_frequencies, num_frames)
+
+        Returns:
+            inverse_transform {tensor} -- Reconstructed audio given magnitude and phase,
+                of shape (num_batch, num_samples)
+        """
+        recombine_magnitude_phase = torch.cat(
+            [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
+        )
+
+        inverse_transform = F.conv_transpose1d(
+            recombine_magnitude_phase,
+            self.inverse_basis,
+            stride=self.hop_length,
+            padding=0,
+        )
+
+        if self.window is not None:
+            window_sum = window_sumsquare(
+                self.window, magnitude.size(-1), hop_length=self.hop_length,
+                win_length=self.win_length, n_fft=self.filter_length,
+                dtype=np.float32,
+            )
+            # remove modulation effects
+            approx_nonzero_indices = torch.from_numpy(
+                np.where(window_sum > tiny(window_sum))[0]
+            )
+            window_sum = torch.from_numpy(window_sum).to(inverse_transform.device)
+            inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices]
+
+            # scale by hop ratio
+            inverse_transform *= float(self.filter_length) / self.hop_length
+
+        inverse_transform = inverse_transform[..., self.pad_amount:]
+        inverse_transform = inverse_transform[..., :self.num_samples]
+        inverse_transform = inverse_transform.squeeze(1)
+
+        return inverse_transform
+
+    def forward(self, input_data):
+        """Take input data (audio) to the STFT domain and then back to audio.
+
+        Arguments:
+            input_data {tensor} -- Tensor of floats, with shape (num_batch, num_samples)
+
+        Returns:
+            reconstruction {tensor} -- Reconstructed audio given magnitude and phase,
+                of shape (num_batch, num_samples)
+        """
+        # NOTE: transform() returns only the magnitude in this commit, so this
+        # 2-tuple unpacking would fail; the round-trip forward() is unused by RMVPE.
+        self.magnitude, self.phase = self.transform(input_data)
+        reconstruction = self.inverse(self.magnitude, self.phase)
+        return reconstruction
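+# Editor's sketch (not part of the original commit): standalone use of the
+# conv-based STFT, e.g. on backends where torch.stft is unavailable:
+#
+#     stft = STFT(filter_length=1024, hop_length=512, win_length=1024).to(dev)
+#     mag = stft.transform(torch.randn(1, 16000, device=dev))  # (1, 513, n_frames)
+#
+# Note that transform() returns only the magnitude here, so inverse()/forward()
+# cannot be driven from its output without re-enabling the phase computation.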
+from time import time as ttime
 class BiGRU(nn.Module):
     def __init__(self, input_features, hidden_features, num_layers):
         super(BiGRU, self).__init__()
@@ -250,9 +439,11 @@ def __init__(
        )

    def forward(self, mel):
+        # mel: (B, n_mels, T) -> (B, 1, T, n_mels)
        mel = mel.transpose(-1, -2).unsqueeze(1)
        x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2)
        x = self.fc(x)
+        # x: (B, T, 360) frame-level salience over 20-cent pitch bins
        return x

@@ -301,18 +492,33 @@ def forward(self, audio, keyshift=0, speed=1, center=True):
        keyshift_key = str(keyshift) + "_" + str(audio.device)
        if keyshift_key not in self.hann_window:
            self.hann_window[keyshift_key] = torch.hann_window(win_length_new).to(
                audio.device
            )
-        fft = torch.stft(
-            audio,
-            n_fft=n_fft_new,
-            hop_length=hop_length_new,
-            win_length=win_length_new,
-            window=self.hann_window[keyshift_key],
-            center=center,
-            return_complex=True,
-        )
-        magnitude = torch.sqrt(fft.real.pow(2) + fft.imag.pow(2))
+        # torch.stft is not supported by pytorch_dml (DirectML), so use the
+        # conv1d-based STFT defined above instead.
+        if not hasattr(self, "stft"):
+            self.stft = STFT(
+                filter_length=n_fft_new,
+                hop_length=hop_length_new,
+                win_length=win_length_new,
+                window='hann',
+            ).to(audio.device)
+        magnitude = self.stft.transform(audio)  # phase is not returned
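+        # NOTE (editor): self.stft is created once, from the first call's
+        # n_fft_new/hop_length_new; if a later call uses a different keyshift
+        # (and hence a different n_fft_new), this cached instance is stale,
+        # unlike self.hann_window, which is keyed by keyshift.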
        if keyshift != 0:
            size = self.n_fft // 2 + 1
            resize = magnitude.size(1)
@@ -323,19 +529,13 @@ def forward(self, audio, keyshift=0, speed=1, center=True):
        if self.is_half == True:
            mel_output = mel_output.half()
        log_mel_spec = torch.log(torch.clamp(mel_output, min=self.clamp))
        return log_mel_spec


class RMVPE:
    def __init__(self, model_path, is_half, device=None):
        self.resample_kernel = {}
-        model = E2E(4, 1, (2, 2))
-        ckpt = torch.load(model_path, map_location="cpu")
-        model.load_state_dict(ckpt)
-        model.eval()
-        if is_half == True:
-            model = model.half()
-        self.model = model
        self.resample_kernel = {}
        self.is_half = is_half
        if device is None:
@@ -344,7 +544,19 @@ def __init__(self, model_path, is_half, device=None):
        self.mel_extractor = MelSpectrogram(
            is_half, 128, 16000, 1024, 160, None, 30, 8000
        ).to(device)
-        self.model = self.model.to(device)
+        if "privateuseone" in str(device):
+            import onnxruntime as ort
+            ort_session = ort.InferenceSession(
+                "rmvpe.onnx", providers=["DmlExecutionProvider"]
+            )
+            self.model = ort_session
+        else:
+            model = E2E(4, 1, (2, 2))
+            ckpt = torch.load(model_path, map_location="cpu")
+            model.load_state_dict(ckpt)
+            model.eval()
+            if is_half:
+                model = model.half()
+            self.model = model.to(device)
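+        # NOTE (editor): in the DirectML branch above, model_path (the .pt
+        # checkpoint) is ignored; a pre-exported "rmvpe.onnx" in the working
+        # directory is loaded instead.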
        cents_mapping = 20 * np.arange(360) + 1997.3794084376191
        self.cents_mapping = np.pad(cents_mapping, (4, 4))  # 368

@@ -354,7 +566,12 @@ def mel2hidden(self, mel):
        mel = F.pad(
            mel, (0, 32 * ((n_frames - 1) // 32 + 1) - n_frames), mode="reflect"
        )
-        hidden = self.model(mel)
+        if "privateuseone" in str(self.device):
+            onnx_input_name = self.model.get_inputs()[0].name
+            onnx_outputs_names = self.model.get_outputs()[0].name
+            hidden = self.model.run(
+                [onnx_outputs_names],
+                input_feed={onnx_input_name: mel.cpu().numpy()},
+            )[0]
+        else:
+            hidden = self.model(mel)
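+        # NOTE (editor): the ONNX path copies the input to CPU and returns a
+        # numpy array, while the torch path returns a tensor on self.device;
+        # infer_from_audio() below handles both cases.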
        return hidden[:, :n_frames]

    def decode(self, hidden, thred=0.03):
@@ -365,21 +582,26 @@ def decode(self, hidden, thred=0.03):
        return f0

    def infer_from_audio(self, audio, thred=0.03):
-        audio = torch.from_numpy(audio).float().to(self.device).unsqueeze(0)
        # torch.cuda.synchronize()
-        # t0=ttime()
-        mel = self.mel_extractor(audio, center=True)
+        t0 = ttime()
+        mel = self.mel_extractor(
+            torch.from_numpy(audio).float().to(self.device).unsqueeze(0), center=True
+        )
        # torch.cuda.synchronize()
-        # t1=ttime()
+        t1 = ttime()
        hidden = self.mel2hidden(mel)
        # torch.cuda.synchronize()
-        # t2=ttime()
-        hidden = hidden.squeeze(0).cpu().numpy()
+        t2 = ttime()
+        if "privateuseone" not in str(self.device):
+            hidden = hidden.squeeze(0).cpu().numpy()
+        else:
+            hidden = hidden[0]
        if self.is_half == True:
            hidden = hidden.astype("float32")
+
        f0 = self.decode(hidden, thred=thred)
        # torch.cuda.synchronize()
-        # t3=ttime()
+        t3 = ttime()
        # print("hmvpe:%s\t%s\t%s\t%s"%(t1-t0,t2-t1,t3-t2,t3-t0))
        return f0

@@ -410,22 +632,23 @@ def to_local_average_cents(self, salience, thred=0.05):
        return devided


-# if __name__ == '__main__':
-#     audio, sampling_rate = sf.read("卢本伟语录~1.wav")
-#     if len(audio.shape) > 1:
-#         audio = librosa.to_mono(audio.transpose(1, 0))
-#     audio_bak = audio.copy()
-#     if sampling_rate != 16000:
-#         audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=16000)
-#     model_path = "/bili-coeus/jupyter/jupyterhub-liujing04/vits_ch/test-RMVPE/weights/rmvpe_llc_half.pt"
-#     thred = 0.03  # 0.01
-#     device = 'cuda' if torch.cuda.is_available() else 'cpu'
-#     rmvpe = RMVPE(model_path, is_half=False, device=device)
-#     t0 = ttime()
-#     f0 = rmvpe.infer_from_audio(audio, thred=thred)
-#     f0 = rmvpe.infer_from_audio(audio, thred=thred)
-#     f0 = rmvpe.infer_from_audio(audio, thred=thred)
-#     f0 = rmvpe.infer_from_audio(audio, thred=thred)
-#     f0 = rmvpe.infer_from_audio(audio, thred=thred)
-#     t1 = ttime()
-#     print(f0.shape, t1 - t0)
+if __name__ == '__main__':
+    import soundfile as sf, librosa
+    audio, sampling_rate = sf.read(r"C:\Users\liujing04\Desktop\Z\冬之花clip1.wav")
+    if len(audio.shape) > 1:
+        audio = librosa.to_mono(audio.transpose(1, 0))
+    audio_bak = audio.copy()
+    if sampling_rate != 16000:
+        audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=16000)
+    model_path = r"D:\BaiduNetdiskDownload\RVC-beta-v2-0727AMD_realtime\rmvpe.pt"
+    thred = 0.03  # 0.01
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    rmvpe = RMVPE(model_path, is_half=False, device=device)
+    t0 = ttime()
+    f0 = rmvpe.infer_from_audio(audio, thred=thred)
+    # f0 = rmvpe.infer_from_audio(audio, thred=thred)
+    # f0 = rmvpe.infer_from_audio(audio, thred=thred)
+    # f0 = rmvpe.infer_from_audio(audio, thred=thred)
+    # f0 = rmvpe.infer_from_audio(audio, thred=thred)
+    t1 = ttime()
+    print(f0.shape, t1 - t0)