from io import BytesIO
import os
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Union
import numpy as np
import torch

logger = logging.getLogger(__name__)

-
-class STFT(torch.nn.Module):
-    def __init__(
-        self, filter_length=1024, hop_length=512, win_length=None, window="hann"
-    ):
-        """
-        This module implements an STFT using 1D convolution and 1D transpose convolutions.
-        This is a bit tricky so there are some cases that probably won't work as working
-        out the same sizes before and after in all overlap add setups is tough. Right now,
-        this code should work with hop lengths that are half the filter length (50% overlap
-        between frames).
-
-        Keyword Arguments:
-            filter_length {int} -- Length of filters used (default: {1024})
-            hop_length {int} -- Hop length of STFT (restrict to 50% overlap between frames) (default: {512})
-            win_length {[type]} -- Length of the window function applied to each frame (if not specified, it
-                equals the filter length). (default: {None})
-            window {str} -- Type of window to use (options are bartlett, hann, hamming, blackman, blackmanharris)
-                (default: {'hann'})
-        """
-        super(STFT, self).__init__()
-        self.filter_length = filter_length
-        self.hop_length = hop_length
-        self.win_length = win_length if win_length else filter_length
-        self.window = window
-        self.forward_transform = None
-        self.pad_amount = int(self.filter_length / 2)
-        fourier_basis = np.fft.fft(np.eye(self.filter_length))
-
-        cutoff = int((self.filter_length / 2 + 1))
-        fourier_basis = np.vstack(
-            [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
-        )
-        forward_basis = torch.FloatTensor(fourier_basis)
-        inverse_basis = torch.FloatTensor(np.linalg.pinv(fourier_basis))
-
-        assert filter_length >= self.win_length
-        # get window and zero center pad it to filter_length
-        fft_window = get_window(window, self.win_length, fftbins=True)
-        fft_window = pad_center(fft_window, size=filter_length)
-        fft_window = torch.from_numpy(fft_window).float()
-
-        # window the bases
-        forward_basis *= fft_window
-        inverse_basis = (inverse_basis.T * fft_window).T
-
-        self.register_buffer("forward_basis", forward_basis.float())
-        self.register_buffer("inverse_basis", inverse_basis.float())
-        self.register_buffer("fft_window", fft_window.float())
-
-    def transform(self, input_data, return_phase=False):
-        """Take input data (audio) to STFT domain.
-
-        Arguments:
-            input_data {tensor} -- Tensor of floats, with shape (num_batch, num_samples)
-
-        Returns:
-            magnitude {tensor} -- Magnitude of STFT with shape (num_batch,
-                num_frequencies, num_frames)
-            phase {tensor} -- Phase of STFT with shape (num_batch,
-                num_frequencies, num_frames)
-        """
-        input_data = F.pad(
-            input_data,
-            (self.pad_amount, self.pad_amount),
-            mode="reflect",
-        )
-        forward_transform = input_data.unfold(
-            1, self.filter_length, self.hop_length
-        ).permute(0, 2, 1)
-        forward_transform = torch.matmul(self.forward_basis, forward_transform)
-        cutoff = int((self.filter_length / 2) + 1)
-        real_part = forward_transform[:, :cutoff, :]
-        imag_part = forward_transform[:, cutoff:, :]
-        magnitude = torch.sqrt(real_part**2 + imag_part**2)
-        if return_phase:
-            phase = torch.atan2(imag_part.data, real_part.data)
-            return magnitude, phase
-        else:
-            return magnitude
-
-    def inverse(self, magnitude, phase):
-        """Call the inverse STFT (iSTFT), given magnitude and phase tensors produced
-        by the ```transform``` function.
-
-        Arguments:
-            magnitude {tensor} -- Magnitude of STFT with shape (num_batch,
-                num_frequencies, num_frames)
-            phase {tensor} -- Phase of STFT with shape (num_batch,
-                num_frequencies, num_frames)
-
-        Returns:
-            inverse_transform {tensor} -- Reconstructed audio given magnitude and phase. Of
-                shape (num_batch, num_samples)
-        """
-        cat = torch.cat(
-            [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
-        )
-        fold = torch.nn.Fold(
-            output_size=(1, (cat.size(-1) - 1) * self.hop_length + self.filter_length),
-            kernel_size=(1, self.filter_length),
-            stride=(1, self.hop_length),
-        )
-        inverse_transform = torch.matmul(self.inverse_basis, cat)
-        inverse_transform = fold(inverse_transform)[
-            :, 0, 0, self.pad_amount : -self.pad_amount
-        ]
-        window_square_sum = (
-            self.fft_window.pow(2).repeat(cat.size(-1), 1).T.unsqueeze(0)
-        )
-        window_square_sum = fold(window_square_sum)[
-            :, 0, 0, self.pad_amount : -self.pad_amount
-        ]
-        inverse_transform /= window_square_sum
-        return inverse_transform
-
-    def forward(self, input_data):
-        """Take input data (audio) to STFT domain and then back to audio.
-
-        Arguments:
-            input_data {tensor} -- Tensor of floats, with shape (num_batch, num_samples)
-
-        Returns:
-            reconstruction {tensor} -- Reconstructed audio given magnitude and phase. Of
-                shape (num_batch, num_samples)
-        """
-        self.magnitude, self.phase = self.transform(input_data, return_phase=True)
-        reconstruction = self.inverse(self.magnitude, self.phase)
-        return reconstruction
-
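For context, a minimal round-trip sketch of how the removed STFT module could be driven; this is illustrative only (not part of the change), and the input size is an assumption that respects the 50% overlap the docstring asks for:

# Illustrative usage sketch, not part of the diff.
stft = STFT(filter_length=1024, hop_length=512, win_length=1024, window="hann")
audio = torch.randn(1, 16000)  # (num_batch, num_samples); example input
magnitude, phase = stft.transform(audio, return_phase=True)  # each (1, 513, num_frames)
reconstruction = stft.inverse(magnitude, phase)  # waveform; length is a whole number of hops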
+from rvc.f0.mel import MelSpectrogram

from time import time as ttime
@@ -412,86 +283,6 @@ def forward(self, mel):
        return x


-from librosa.filters import mel
-
-
-class MelSpectrogram(torch.nn.Module):
-    def __init__(
-        self,
-        is_half,
-        n_mel_channels,
-        sampling_rate,
-        win_length,
-        hop_length,
-        n_fft=None,
-        mel_fmin=0,
-        mel_fmax=None,
-        clamp=1e-5,
-    ):
-        super().__init__()
-        n_fft = win_length if n_fft is None else n_fft
-        self.hann_window = {}
-        mel_basis = mel(
-            sr=sampling_rate,
-            n_fft=n_fft,
-            n_mels=n_mel_channels,
-            fmin=mel_fmin,
-            fmax=mel_fmax,
-            htk=True,
-        )
-        mel_basis = torch.from_numpy(mel_basis).float()
-        self.register_buffer("mel_basis", mel_basis)
-        self.n_fft = win_length if n_fft is None else n_fft
-        self.hop_length = hop_length
-        self.win_length = win_length
-        self.sampling_rate = sampling_rate
-        self.n_mel_channels = n_mel_channels
-        self.clamp = clamp
-        self.is_half = is_half
-
-    def forward(self, audio, keyshift=0, speed=1, center=True):
-        factor = 2 ** (keyshift / 12)
-        n_fft_new = int(np.round(self.n_fft * factor))
-        win_length_new = int(np.round(self.win_length * factor))
-        hop_length_new = int(np.round(self.hop_length * speed))
-        keyshift_key = str(keyshift) + "_" + str(audio.device)
-        if keyshift_key not in self.hann_window:
-            self.hann_window[keyshift_key] = torch.hann_window(win_length_new).to(
-                audio.device
-            )
-        if "privateuseone" in str(audio.device):
-            if not hasattr(self, "stft"):
-                self.stft = STFT(
-                    filter_length=n_fft_new,
-                    hop_length=hop_length_new,
-                    win_length=win_length_new,
-                    window="hann",
-                ).to(audio.device)
-            magnitude = self.stft.transform(audio)
-        else:
-            fft = torch.stft(
-                audio,
-                n_fft=n_fft_new,
-                hop_length=hop_length_new,
-                win_length=win_length_new,
-                window=self.hann_window[keyshift_key],
-                center=center,
-                return_complex=True,
-            )
-            magnitude = torch.sqrt(fft.real.pow(2) + fft.imag.pow(2))
-        if keyshift != 0:
-            size = self.n_fft // 2 + 1
-            resize = magnitude.size(1)
-            if resize < size:
-                magnitude = F.pad(magnitude, (0, 0, 0, size - resize))
-            magnitude = magnitude[:, :size, :] * self.win_length / win_length_new
-        mel_output = torch.matmul(self.mel_basis, magnitude)
-        if self.is_half == True:
-            mel_output = mel_output.half()
-        log_mel_spec = torch.log(torch.clamp(mel_output, min=self.clamp))
-        return log_mel_spec
-
-
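As a point of reference, a hedged sketch of extracting a log-mel spectrogram with the parameters RMVPE passes below; the removed class above accepts exactly these keywords, and the replacement imported from rvc.f0.mel is assumed to accept them as well (plus the device keyword used in the updated constructor call):

# Illustrative sketch, not part of the diff.
mel_extractor = MelSpectrogram(
    is_half=False,
    n_mel_channels=128,
    sampling_rate=16000,
    win_length=1024,
    hop_length=160,
    mel_fmin=30,
    mel_fmax=8000,
)
audio = torch.randn(1, 16000)   # one second of 16 kHz audio, illustrative
log_mel = mel_extractor(audio)  # (1, 128, num_frames); clamped at 1e-5 before the log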
class RMVPE:
    def __init__(self, model_path: str, is_half, device=None, use_jit=False):
        self.resample_kernel = {}
@@ -501,7 +292,14 @@ def __init__(self, model_path: str, is_half, device=None, use_jit=False):
            device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.mel_extractor = MelSpectrogram(
-            is_half, 128, 16000, 1024, 160, None, 30, 8000
+            is_half=is_half,
+            n_mel_channels=128,
+            sampling_rate=16000,
+            win_length=1024,
+            hop_length=160,
+            mel_fmin=30,
+            mel_fmax=8000,
+            device=device,
        ).to(device)
        if "privateuseone" in str(device):
            import onnxruntime as ort
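Finally, a minimal construction sketch under the RMVPE signature shown above; the checkpoint path is purely illustrative, and passing device=None falls back to CUDA when available, as the constructor shows:

# Illustrative sketch, not part of the diff; the path is a placeholder.
rmvpe = RMVPE("assets/rmvpe/rmvpe.pt", is_half=False, device="cpu")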