1
+ import soundfile as sf
2
+ import torch ,pdb ,time ,argparse ,os ,warnings ,sys ,librosa
3
+ import numpy as np
4
+ import onnxruntime as ort
5
+ from scipy .io .wavfile import write
6
+ from tqdm import tqdm
7
+ import torch
8
+ import torch .nn as nn
9
+
10
+ dim_c = 4
11
+ class Conv_TDF_net_trim ():
12
+ def __init__ (self , device , model_name , target_name ,
13
+ L , dim_f , dim_t , n_fft , hop = 1024 ):
14
+ super (Conv_TDF_net_trim , self ).__init__ ()
15
+
16
+ self .dim_f = dim_f
17
+ self .dim_t = 2 ** dim_t
18
+ self .n_fft = n_fft
19
+ self .hop = hop
20
+ self .n_bins = self .n_fft // 2 + 1
21
+ self .chunk_size = hop * (self .dim_t - 1 )
22
+ self .window = torch .hann_window (window_length = self .n_fft , periodic = True ).to (device )
23
+ self .target_name = target_name
24
+ self .blender = 'blender' in model_name
25
+
26
+ out_c = dim_c * 4 if target_name == '*' else dim_c
27
+ self .freq_pad = torch .zeros ([1 , out_c , self .n_bins - self .dim_f , self .dim_t ]).to (device )
28
+
29
+ self .n = L // 2
30
+
31
+ def stft (self , x ):
32
+ x = x .reshape ([- 1 , self .chunk_size ])
33
+ x = torch .stft (x , n_fft = self .n_fft , hop_length = self .hop , window = self .window , center = True , return_complex = True )
34
+ x = torch .view_as_real (x )
35
+ x = x .permute ([0 , 3 , 1 , 2 ])
36
+ x = x .reshape ([- 1 , 2 , 2 , self .n_bins , self .dim_t ]).reshape ([- 1 , dim_c , self .n_bins , self .dim_t ])
37
+ return x [:, :, :self .dim_f ]
38
+
39
+ def istft (self , x , freq_pad = None ):
40
+ freq_pad = self .freq_pad .repeat ([x .shape [0 ], 1 , 1 , 1 ]) if freq_pad is None else freq_pad
41
+ x = torch .cat ([x , freq_pad ], - 2 )
42
+ c = 4 * 2 if self .target_name == '*' else 2
43
+ x = x .reshape ([- 1 , c , 2 , self .n_bins , self .dim_t ]).reshape ([- 1 , 2 , self .n_bins , self .dim_t ])
44
+ x = x .permute ([0 , 2 , 3 , 1 ])
45
+ x = x .contiguous ()
46
+ x = torch .view_as_complex (x )
47
+ x = torch .istft (x , n_fft = self .n_fft , hop_length = self .hop , window = self .window , center = True )
48
+ return x .reshape ([- 1 , c , self .chunk_size ])
49
+ def get_models (device , dim_f , dim_t , n_fft ):
50
+ return Conv_TDF_net_trim (
51
+ device = device ,
52
+ model_name = 'Conv-TDF' , target_name = 'vocals' ,
53
+ L = 11 ,
54
+ dim_f = dim_f , dim_t = dim_t ,
55
+ n_fft = n_fft
56
+ )
57
+
58
+ warnings .filterwarnings ("ignore" )
59
+ cpu = torch .device ('cpu' )
60
+ device = torch .device ('cuda:0' if torch .cuda .is_available () else 'cpu' )
61
+
62
+ class Predictor :
63
+ def __init__ (self ,args ):
64
+ self .args = args
65
+ self .model_ = get_models (device = cpu , dim_f = args .dim_f , dim_t = args .dim_t , n_fft = args .n_fft )
66
+ self .model = ort .InferenceSession (os .path .join (args .onnx ,self .model_ .target_name + '.onnx' ), providers = ['CUDAExecutionProvider' , 'CPUExecutionProvider' ])
67
+ print ('onnx load done' )
68
+ def demix (self , mix ):
69
+ samples = mix .shape [- 1 ]
70
+ margin = self .args .margin
71
+ chunk_size = self .args .chunks * 44100
72
+ assert not margin == 0 , 'margin cannot be zero!'
73
+ if margin > chunk_size :
74
+ margin = chunk_size
75
+
76
+ segmented_mix = {}
77
+
78
+ if self .args .chunks == 0 or samples < chunk_size :
79
+ chunk_size = samples
80
+
81
+ counter = - 1
82
+ for skip in range (0 , samples , chunk_size ):
83
+ counter += 1
84
+
85
+ s_margin = 0 if counter == 0 else margin
86
+ end = min (skip + chunk_size + margin , samples )
87
+
88
+ start = skip - s_margin
89
+
90
+ segmented_mix [skip ] = mix [:,start :end ].copy ()
91
+ if end == samples :
92
+ break
93
+
94
+ sources = self .demix_base (segmented_mix , margin_size = margin )
95
+ '''
96
+ mix:(2,big_sample)
97
+ segmented_mix:offset->(2,small_sample)
98
+ sources:(1,2,big_sample)
99
+ '''
100
+ return sources
101
+ def demix_base (self , mixes , margin_size ):
102
+ chunked_sources = []
103
+ progress_bar = tqdm (total = len (mixes ))
104
+ progress_bar .set_description ("Processing" )
105
+ for mix in mixes :
106
+ cmix = mixes [mix ]
107
+ sources = []
108
+ n_sample = cmix .shape [1 ]
109
+ model = self .model_
110
+ trim = model .n_fft // 2
111
+ gen_size = model .chunk_size - 2 * trim
112
+ pad = gen_size - n_sample % gen_size
113
+ mix_p = np .concatenate ((np .zeros ((2 ,trim )), cmix , np .zeros ((2 ,pad )), np .zeros ((2 ,trim ))), 1 )
114
+ mix_waves = []
115
+ i = 0
116
+ while i < n_sample + pad :
117
+ waves = np .array (mix_p [:, i :i + model .chunk_size ])
118
+ mix_waves .append (waves )
119
+ i += gen_size
120
+ mix_waves = torch .tensor (mix_waves , dtype = torch .float32 ).to (cpu )
121
+ with torch .no_grad ():
122
+ _ort = self .model
123
+ spek = model .stft (mix_waves )
124
+ if self .args .denoise :
125
+ spec_pred = - _ort .run (None , {'input' : - spek .cpu ().numpy ()})[0 ]* 0.5 + _ort .run (None , {'input' : spek .cpu ().numpy ()})[0 ]* 0.5
126
+ tar_waves = model .istft (torch .tensor (spec_pred ))
127
+ else :
128
+ tar_waves = model .istft (torch .tensor (_ort .run (None , {'input' : spek .cpu ().numpy ()})[0 ]))
129
+ tar_signal = tar_waves [:,:,trim :- trim ].transpose (0 ,1 ).reshape (2 , - 1 ).numpy ()[:, :- pad ]
130
+
131
+ start = 0 if mix == 0 else margin_size
132
+ end = None if mix == list (mixes .keys ())[::- 1 ][0 ] else - margin_size
133
+ if margin_size == 0 :
134
+ end = None
135
+ sources .append (tar_signal [:,start :end ])
136
+
137
+ progress_bar .update (1 )
138
+
139
+ chunked_sources .append (sources )
140
+ _sources = np .concatenate (chunked_sources , axis = - 1 )
141
+ # del self.model
142
+ progress_bar .close ()
143
+ return _sources
144
+ def prediction (self , m ,vocal_root ,others_root ):
145
+ os .makedirs (vocal_root ,exist_ok = True )
146
+ os .makedirs (others_root ,exist_ok = True )
147
+ basename = os .path .basename (m )
148
+ mix , rate = librosa .load (m , mono = False , sr = 44100 )
149
+ if mix .ndim == 1 :
150
+ mix = np .asfortranarray ([mix ,mix ])
151
+ mix = mix .T
152
+ sources = self .demix (mix .T )
153
+ opt = sources [0 ].T
154
+ sf .write ("%s/%s_main_vocal.wav" % (vocal_root ,basename ), mix - opt , rate )
155
+ sf .write ("%s/%s_others.wav" % (others_root ,basename ), opt , rate )
156
+
157
+ class MDXNetDereverb ():
158
+ def __init__ (self ,chunks ):
159
+ self .onnx = "uvr5_weights/onnx_dereverb_By_FoxJoy"
160
+ self .shifts = 10 #'Predict with randomised equivariant stabilisation'
161
+ self .mixing = "min_mag" #['default','min_mag','max_mag']
162
+ self .chunks = chunks
163
+ self .margin = 44100
164
+ self .dim_t = 9
165
+ self .dim_f = 3072
166
+ self .n_fft = 6144
167
+ self .denoise = True
168
+ self .pred = Predictor (self )
169
+
170
+ def _path_audio_ (self ,input ,vocal_root ,others_root ):
171
+ self .pred .prediction (input ,vocal_root ,others_root )
172
+
173
+ if __name__ == '__main__' :
174
+ dereverb = MDXNetDereverb (15 )
175
+ from time import time as ttime
176
+ t0 = ttime ()
177
+ dereverb ._path_audio_ (
178
+ "雪雪伴奏对消HP5.wav" ,
179
+ "vocal" ,
180
+ "others" ,
181
+ )
182
+ t1 = ttime ()
183
+ print (t1 - t0 )
184
+
185
+
186
+ '''
187
+
188
+ runtime\python.exe MDXNet.py
189
+
190
+ 6G:
191
+ 15/9:0.8G->6.8G
192
+ 14:0.8G->6.5G
193
+ 25:炸
194
+
195
+ half15:0.7G->6.6G,22.69s
196
+ fp32-15:0.7G->6.6G,20.85s
197
+
198
+ '''
0 commit comments