@@ -61,9 +61,7 @@ def forward_dml(ctx, x, scale):
61
61
self .cache_pitch : torch .Tensor = torch .zeros (
62
62
1024 , device = self .device , dtype = torch .long
63
63
)
64
- self .cache_pitchf = torch .zeros (
65
- 1024 , device = self .device , dtype = torch .float32
66
- )
64
+ self .cache_pitchf = torch .zeros (1024 , device = self .device , dtype = torch .float32 )
67
65
68
66
self .resample_kernel = {}
69
67
@@ -111,13 +109,15 @@ def set_jit_model():
111
109
self .tgt_sr = cpt ["config" ][- 1 ]
112
110
self .if_f0 = cpt .get ("f0" , 1 )
113
111
self .version = cpt .get ("version" , "v1" )
114
- self .net_g = torch .jit .load (
115
- BytesIO (cpt ["model" ]), map_location = self .device
116
- )
112
+ self .net_g = torch .jit .load (BytesIO (cpt ["model" ]), map_location = self .device )
117
113
self .net_g .infer = self .net_g .forward
118
114
self .net_g .eval ().to (self .device )
119
115
120
- if self .use_jit and not is_dml and not (self .is_half and "cpu" in str (self .device )):
116
+ if (
117
+ self .use_jit
118
+ and not is_dml
119
+ and not (self .is_half and "cpu" in str (self .device ))
120
+ ):
121
121
set_jit_model ()
122
122
else :
123
123
set_default_model ()
@@ -202,9 +202,13 @@ def infer(
202
202
elif self .if_f0 == 1 :
203
203
f0_extractor_frame = block_frame_16k + 800
204
204
if f0method == "rmvpe" :
205
- f0_extractor_frame = 5120 * ((f0_extractor_frame - 1 ) // 5120 + 1 ) - self .window
205
+ f0_extractor_frame = (
206
+ 5120 * ((f0_extractor_frame - 1 ) // 5120 + 1 ) - self .window
207
+ )
206
208
if inp_f0 is not None :
207
- pitch , pitchf = self ._get_f0_post (inp_f0 , self .f0_up_key - self .formant_shift )
209
+ pitch , pitchf = self ._get_f0_post (
210
+ inp_f0 , self .f0_up_key - self .formant_shift
211
+ )
208
212
else :
209
213
pitch , pitchf = self ._get_f0 (
210
214
input_wav [- f0_extractor_frame :],
@@ -272,12 +276,12 @@ def _get_f0(
272
276
x : torch .Tensor ,
273
277
f0_up_key : Union [int , float ],
274
278
filter_radius : Union [int , float ],
275
- method : Literal ["crepe" , "rmvpe" , "fcpe" , "pm" , "harvest" , "dio" ]= "fcpe" ,
279
+ method : Literal ["crepe" , "rmvpe" , "fcpe" , "pm" , "harvest" , "dio" ] = "fcpe" ,
276
280
):
277
281
if method not in self .f0_methods .keys ():
278
- raise RuntimeError ("Not supported f0 method: " + method )
282
+ raise RuntimeError ("Not supported f0 method: " + method )
279
283
return self .f0_methods [method ](x , f0_up_key , filter_radius )
280
-
284
+
281
285
def _get_f0_post (self , f0 , f0_up_key ):
282
286
f0 *= pow (2 , f0_up_key / 12 )
283
287
if not torch .is_tensor (f0 ):
@@ -297,7 +301,7 @@ def _get_f0_pm(self, x, f0_up_key, filter_radius):
297
301
self .pm = PM (hop_length = 160 , sampling_rate = 16000 )
298
302
f0 = self .pm .compute_f0 (x )
299
303
return self ._get_f0_post (f0 , f0_up_key )
300
-
304
+
301
305
def _get_f0_harvest (self , x , f0_up_key , filter_radius ):
302
306
if not hasattr (self , "harvest" ):
303
307
self .harvest = Harvest (
@@ -308,7 +312,7 @@ def _get_f0_harvest(self, x, f0_up_key, filter_radius):
308
312
)
309
313
f0 = self .harvest .compute_f0 (x , filter_radius = filter_radius )
310
314
return self ._get_f0_post (f0 , f0_up_key )
311
-
315
+
312
316
def _get_f0_dio (self , x , f0_up_key , filter_radius ):
313
317
if not hasattr (self , "dio" ):
314
318
self .dio = Dio (
@@ -341,7 +345,8 @@ def _get_f0_rmvpe(self, x, f0_up_key, filter_radius=0.03):
341
345
use_jit = self .use_jit ,
342
346
)
343
347
return self ._get_f0_post (
344
- self .rmvpe .compute_f0 (x , thred = filter_radius ), f0_up_key ,
348
+ self .rmvpe .compute_f0 (x , thred = filter_radius ),
349
+ f0_up_key ,
345
350
)
346
351
347
352
def _get_f0_fcpe (self , x , f0_up_key , filter_radius ):
0 commit comments