Skip to content

Commit e298fde

Browse files
committed
optimize(crepe): move crepe into rvc.f0
1 parent f79b925 commit e298fde

File tree

7 files changed

+106
-54
lines changed

7 files changed

+106
-54
lines changed

infer/lib/rtrvc.py

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -292,22 +292,16 @@ def get_f0_crepe(self, x, f0_up_key):
292292
self.device
293293
): ###不支持dml,cpu又太慢用不成,拿fcpe顶替
294294
return self.get_f0(x, f0_up_key, 1, "fcpe")
295-
# printt("using crepe,device:%s"%self.device)
296-
f0, pd = torchcrepe.predict(
297-
x.unsqueeze(0).float(),
298-
16000,
299-
160,
300-
self.f0_min,
301-
self.f0_max,
302-
"full",
303-
batch_size=512,
304-
# device=self.device if self.device.type!="privateuseone" else "cpu",###crepe不用半精度全部是全精度所以不愁###cpu延迟高到没法用
305-
device=self.device,
306-
return_periodicity=True,
307-
)
308-
pd = torchcrepe.filter.median(pd, 3)
309-
f0 = torchcrepe.filter.mean(f0, 3)
310-
f0[pd < 0.1] = 0
295+
if hasattr(self, "model_crepe") == False:
296+
from rvc.f0 import CRePE
297+
self.model_crepe = CRePE(
298+
160,
299+
self.f0_min,
300+
self.f0_max,
301+
16000,
302+
self.device,
303+
)
304+
f0 = self.model_crepe.compute_f0(x)
311305
f0 *= pow(2, f0_up_key / 12)
312306
return self.get_f0_post(f0)
313307

infer/modules/vc/pipeline.py

Lines changed: 15 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,9 @@
1212
import numpy as np
1313
import torch
1414
import torch.nn.functional as F
15-
import torchcrepe
1615
from scipy import signal
1716

18-
from rvc.f0 import PM, Harvest, RMVPE
17+
from rvc.f0 import PM, Harvest, RMVPE, CRePE, Dio
1918

2019
now_dir = os.getcwd()
2120
sys.path.append(now_dir)
@@ -81,31 +80,24 @@ def get_f0(
8180
if not hasattr(self, "pm"):
8281
self.pm = PM(self.window, f0_min, f0_max, self.sr)
8382
f0 = self.pm.compute_f0(x, p_len=p_len)
83+
if f0_method == "dio":
84+
if not hasattr(self, "dio"):
85+
self.dio = Dio(self.window, f0_min, f0_max, self.sr)
86+
f0 = self.dio.compute_f0(x, p_len=p_len)
8487
elif f0_method == "harvest":
8588
if not hasattr(self, "harvest"):
8689
self.harvest = Harvest(self.window, f0_min, f0_max, self.sr)
8790
f0 = self.harvest.compute_f0(x, p_len=p_len, filter_radius=filter_radius)
8891
elif f0_method == "crepe":
89-
model = "full"
90-
# Pick a batch size that doesn't cause memory errors on your gpu
91-
batch_size = 512
92-
# Compute pitch using first gpu
93-
audio = torch.tensor(np.copy(x))[None].float()
94-
f0, pd = torchcrepe.predict(
95-
audio,
96-
self.sr,
97-
self.window,
98-
f0_min,
99-
f0_max,
100-
model,
101-
batch_size=batch_size,
102-
device=self.device,
103-
return_periodicity=True,
104-
)
105-
pd = torchcrepe.filter.median(pd, 3)
106-
f0 = torchcrepe.filter.mean(f0, 3)
107-
f0[pd < 0.1] = 0
108-
f0 = f0[0].cpu().numpy()
92+
if not hasattr(self, "crepe"):
93+
self.crepe = CRePE(
94+
self.window,
95+
f0_min,
96+
f0_max,
97+
self.sr,
98+
self.device,
99+
)
100+
f0 = self.crepe.compute_f0(x, p_len=p_len)
109101
elif f0_method == "rmvpe":
110102
if not hasattr(self, "rmvpe"):
111103
logger.info(
@@ -117,7 +109,7 @@ def get_f0(
117109
device=self.device,
118110
# use_jit=self.config.use_jit,
119111
)
120-
f0 = self.rmvpe.compute_f0(x, filter_radius=0.03)
112+
f0 = self.rmvpe.compute_f0(x, p_len=p_len, filter_radius=0.03)
121113

122114
if "privateuseone" in str(self.device): # clean ortruntime memory
123115
del self.rmvpe.model

rvc/f0/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from .f0 import F0Predictor
22

3+
from .crepe import CRePE
34
from .dio import Dio
45
from .harvest import Harvest
56
from .pm import PM

rvc/f0/crepe.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
from typing import Any, Optional, Union
2+
3+
import numpy as np
4+
import torch
5+
import torchcrepe
6+
7+
from .f0 import F0Predictor
8+
9+
10+
class CRePE(F0Predictor):
11+
def __init__(
12+
self,
13+
hop_length=512,
14+
f0_min=50,
15+
f0_max=1100,
16+
sampling_rate=44100,
17+
device="cpu",
18+
):
19+
super().__init__(
20+
hop_length,
21+
f0_min,
22+
f0_max,
23+
sampling_rate,
24+
device,
25+
)
26+
27+
def compute_f0(
28+
self,
29+
wav: np.ndarray[Any, np.dtype],
30+
p_len: Optional[int] = None,
31+
filter_radius: Optional[Union[int, float]] = None,
32+
):
33+
if p_len is None:
34+
p_len = wav.shape[0] // self.hop_length
35+
# Pick a batch size that doesn't cause memory errors on your gpu
36+
batch_size = 512
37+
# Compute pitch using device 'device'
38+
f0, pd = torchcrepe.predict(
39+
torch.tensor(np.copy(wav))[None].float().to(self.device),
40+
self.sampling_rate,
41+
self.hop_length,
42+
self.f0_min,
43+
self.f0_max,
44+
batch_size=batch_size,
45+
device=self.device,
46+
return_periodicity=True,
47+
)
48+
pd = torchcrepe.filter.median(pd, 3)
49+
f0 = torchcrepe.filter.mean(f0, 3)
50+
f0[pd < 0.1] = 0
51+
f0 = f0[0].cpu().numpy()
52+
return self._interpolate_f0(self._resize_f0(f0, p_len))[0]

rvc/f0/f0.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,25 @@
11
from typing import Any, Optional, Union
22

3+
import torch
34
import numpy as np
45

56

67
class F0Predictor(object):
7-
def __init__(self, hop_length=512, f0_min=50, f0_max=1100, sampling_rate=44100):
8+
def __init__(
9+
self,
10+
hop_length=512,
11+
f0_min=50,
12+
f0_max=1100,
13+
sampling_rate=44100,
14+
device: Optional[str] = None,
15+
):
816
self.hop_length = hop_length
917
self.f0_min = f0_min
1018
self.f0_max = f0_max
1119
self.sampling_rate = sampling_rate
20+
if device is None:
21+
device = "cuda:0" if torch.cuda.is_available() else "cpu"
22+
self.device = device
1223

1324
def compute_f0(
1425
self,

rvc/f0/rmvpe.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,16 +26,18 @@ def __init__(
2626
f0_max = 8000
2727
sampling_rate = 16000
2828

29-
super().__init__(hop_length, f0_min, f0_max, sampling_rate)
29+
super().__init__(
30+
hop_length,
31+
f0_min,
32+
f0_max,
33+
sampling_rate,
34+
device,
35+
)
3036

3137
self.is_half = is_half
3238
cents_mapping = 20 * np.arange(360) + 1997.3794084376191
3339
self.cents_mapping = np.pad(cents_mapping, (4, 4)) # 368
3440

35-
if device is None:
36-
device = "cuda:0" if torch.cuda.is_available() else "cpu"
37-
self.device = device
38-
3941
self.mel_extractor = MelSpectrogram(
4042
is_half=is_half,
4143
n_mel_channels=128,
@@ -44,10 +46,10 @@ def __init__(
4446
hop_length=hop_length,
4547
mel_fmin=f0_min,
4648
mel_fmax=f0_max,
47-
device=device,
48-
).to(device)
49+
device=self.device,
50+
).to(self.device)
4951

50-
if "privateuseone" in str(device):
52+
if "privateuseone" in str(self.device):
5153
import onnxruntime as ort
5254

5355
self.model = ort.InferenceSession(
@@ -73,11 +75,11 @@ def get_jit_model():
7375
mode="script",
7476
inputs_path=None,
7577
save_path=jit_model_path,
76-
device=device,
78+
device=self.device,
7779
is_half=is_half,
7880
)
7981

80-
model = torch.jit.load(BytesIO(ckpt["model"]), map_location=device)
82+
model = torch.jit.load(BytesIO(ckpt["model"]), map_location=self.device)
8183
return model
8284

8385
def get_default_model():
@@ -99,7 +101,7 @@ def get_default_model():
99101
else:
100102
self.model = get_default_model()
101103

102-
self.model = self.model.to(device)
104+
self.model = self.model.to(self.device)
103105

104106
def compute_f0(
105107
self,

web.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -861,9 +861,9 @@ def change_f0_method(f0method8):
861861
"Select the pitch extraction algorithm ('pm': faster extraction but lower-quality speech; 'harvest': better bass but extremely slow; 'crepe': better quality but GPU intensive), 'rmvpe': best quality, and little GPU requirement"
862862
),
863863
choices=(
864-
["pm", "harvest", "crepe", "rmvpe"]
865-
if config.dml == False
866-
else ["pm", "harvest", "rmvpe"]
864+
["pm", "dio", "harvest", "rmvpe"]
865+
if config.dml
866+
else ["pm", "dio", "harvest", "crepe", "rmvpe"]
867867
),
868868
value="rmvpe",
869869
interactive=True,

0 commit comments

Comments
 (0)