Skip to content

Commit 8ac5597

Browse files
committed
optimize(rmvpe): move rmvpe into rvc.f0
1 parent 77b371d commit 8ac5597

File tree

12 files changed

+94
-93
lines changed

12 files changed

+94
-93
lines changed

infer/lib/rtrvc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,7 @@ def get_f0_crepe(self, x, f0_up_key):
313313

314314
def get_f0_rmvpe(self, x, f0_up_key):
315315
if hasattr(self, "model_rmvpe") == False:
316-
from infer.lib.rmvpe import RMVPE
316+
from rvc.f0 import RMVPE
317317

318318
printt("Loading rmvpe model")
319319
self.model_rmvpe = RMVPE(
@@ -322,7 +322,7 @@ def get_f0_rmvpe(self, x, f0_up_key):
322322
device=self.device,
323323
use_jit=self.config.use_jit,
324324
)
325-
f0 = self.model_rmvpe.infer_from_audio(x, thred=0.03)
325+
f0 = self.model_rmvpe.compute_f0(x, thred=0.03)
326326
f0 *= pow(2, f0_up_key / 12)
327327
return self.get_f0_post(f0)
328328

infer/modules/train/extract/extract_f0_print.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,13 +83,13 @@ def compute_f0(self, path, f0_method):
8383
f0 = pyworld.stonemask(x.astype(np.double), f0, t, self.fs)
8484
elif f0_method == "rmvpe":
8585
if hasattr(self, "model_rmvpe") == False:
86-
from infer.lib.rmvpe import RMVPE
86+
from rvc.f0.rmvpe import RMVPE
8787

8888
print("Loading rmvpe model")
8989
self.model_rmvpe = RMVPE(
9090
"assets/rmvpe/rmvpe.pt", is_half=False, device="cpu"
9191
)
92-
f0 = self.model_rmvpe.infer_from_audio(x, threshold=0.03)
92+
f0 = self.model_rmvpe.compute_f0(x, filter_radius=0.03)
9393
return f0
9494

9595
def coarse_f0(self, f0):

infer/modules/train/extract/extract_f0_rmvpe.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,13 +46,13 @@ def compute_f0(self, path, f0_method):
4646
# p_len = x.shape[0] // self.hop
4747
if f0_method == "rmvpe":
4848
if hasattr(self, "model_rmvpe") == False:
49-
from infer.lib.rmvpe import RMVPE
49+
from rvc.f0.rmvpe import RMVPE
5050

5151
print("Loading rmvpe model")
5252
self.model_rmvpe = RMVPE(
5353
"assets/rmvpe/rmvpe.pt", is_half=is_half, device="cuda"
5454
)
55-
f0 = self.model_rmvpe.infer_from_audio(x, threshold=0.03)
55+
f0 = self.model_rmvpe.compute_f0(x, filter_radius=0.03)
5656
return f0
5757

5858
def coarse_f0(self, f0):

infer/modules/train/extract/extract_f0_rmvpe_dml.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,13 +44,13 @@ def compute_f0(self, path, f0_method):
4444
# p_len = x.shape[0] // self.hop
4545
if f0_method == "rmvpe":
4646
if hasattr(self, "model_rmvpe") == False:
47-
from infer.lib.rmvpe import RMVPE
47+
from rvc.f0.rmvpe import RMVPE
4848

4949
print("Loading rmvpe model")
5050
self.model_rmvpe = RMVPE(
5151
"assets/rmvpe/rmvpe.pt", is_half=False, device=device
5252
)
53-
f0 = self.model_rmvpe.infer_from_audio(x, threshold=0.03)
53+
f0 = self.model_rmvpe.compute_f0(x, filter_radius=0.03)
5454
return f0
5555

5656
def coarse_f0(self, f0):

infer/modules/vc/pipeline.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import torchcrepe
1717
from scipy import signal
1818

19-
from rvc.f0 import PM, Harvest
19+
from rvc.f0 import PM, Harvest, RMVPE
2020

2121
now_dir = os.getcwd()
2222
sys.path.append(now_dir)
@@ -108,24 +108,23 @@ def get_f0(
108108
f0[pd < 0.1] = 0
109109
f0 = f0[0].cpu().numpy()
110110
elif f0_method == "rmvpe":
111-
if not hasattr(self, "model_rmvpe"):
112-
from infer.lib.rmvpe import RMVPE
113-
111+
if not hasattr(self, "rmvpe"):
114112
logger.info(
115113
"Loading rmvpe model %s" % "%s/rmvpe.pt" % os.environ["rmvpe_root"]
116114
)
117-
self.model_rmvpe = RMVPE(
115+
self.rmvpe = RMVPE(
118116
"%s/rmvpe.pt" % os.environ["rmvpe_root"],
119117
is_half=self.is_half,
120118
device=self.device,
121119
# use_jit=self.config.use_jit,
122120
)
123-
f0 = self.model_rmvpe.infer_from_audio(x, threshold=0.03)
121+
f0 = self.rmvpe.compute_f0(x, filter_radius=0.03)
124122

125123
if "privateuseone" in str(self.device): # clean ortruntime memory
126-
del self.model_rmvpe.model
127-
del self.model_rmvpe
124+
del self.rmvpe.model
125+
del self.rmvpe
128126
logger.info("Cleaning ortruntime memory")
127+
129128
elif f0_method == "fcpe":
130129
if not hasattr(self, "model_fcpe"):
131130
from torchfcpe import spawn_bundled_infer_model

rvc/f0/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
from .f0 import F0Predictor
2+
13
from .dio import Dio
24
from .harvest import Harvest
35
from .pm import PM
4-
from .f0 import F0Predictor
6+
from .rmvpe import RMVPE

rvc/f0/dio.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Optional
1+
from typing import Any, Optional, Union
22

33
import numpy as np
44
import pyworld
@@ -14,7 +14,7 @@ def compute_f0(
1414
self,
1515
wav: np.ndarray[Any, np.dtype],
1616
p_len: Optional[int] = None,
17-
filter_radius: Optional[int] = None,
17+
filter_radius: Optional[Union[int, float]] = None,
1818
):
1919
if p_len is None:
2020
p_len = wav.shape[0] // self.hop_length

rvc/f0/f0.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Optional
1+
from typing import Any, Optional, Union
22

33
import numpy as np
44

@@ -14,7 +14,7 @@ def compute_f0(
1414
self,
1515
wav: np.ndarray[Any, np.dtype],
1616
p_len: Optional[int] = None,
17-
filter_radius: Optional[int] = None,
17+
filter_radius: Optional[Union[int, float]] = None,
1818
): ...
1919

2020
def interpolate_f0(self, f0: np.ndarray[Any, np.dtype]):

rvc/f0/harvest.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Optional
1+
from typing import Any, Optional, Union
22

33
import numpy as np
44
import pyworld
@@ -15,7 +15,7 @@ def compute_f0(
1515
self,
1616
wav: np.ndarray[Any, np.dtype],
1717
p_len: Optional[int] = None,
18-
filter_radius: Optional[int] = None,
18+
filter_radius: Optional[Union[int, float]] = None,
1919
):
2020
if p_len is None:
2121
p_len = wav.shape[0] // self.hop_length
rvc/f0/rmvpe.py (moved from infer/lib/rmvpe.py)

Lines changed: 70 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -1,51 +1,60 @@
11
from io import BytesIO
22
import os
3-
from typing import List, Optional, Tuple, Union
3+
from typing import Any, Optional, Union
4+
45
import numpy as np
56
import torch
7+
import torch.nn.functional as F
68

79
from infer.lib import jit
810

9-
import torch.nn.functional as F
10-
11-
import logging
11+
from .mel import MelSpectrogram
12+
from .e2e import E2E
13+
from .f0 import F0Predictor
1214

13-
logger = logging.getLogger(__name__)
1415

15-
from rvc.f0.mel import MelSpectrogram
16-
from rvc.f0.e2e import E2E
16+
class RMVPE(F0Predictor):
17+
def __init__(
18+
self,
19+
model_path: str,
20+
is_half: bool,
21+
device: str,
22+
use_jit=False,
23+
):
24+
hop_length=160
25+
f0_min=30
26+
f0_max=8000
27+
sampling_rate=16000
1728

29+
super().__init__(hop_length, f0_min, f0_max, sampling_rate)
1830

19-
class RMVPE:
20-
def __init__(self, model_path: str, is_half, device=None, use_jit=False):
21-
self.resample_kernel = {}
22-
self.resample_kernel = {}
2331
self.is_half = is_half
32+
cents_mapping = 20 * np.arange(360) + 1997.3794084376191
33+
self.cents_mapping = np.pad(cents_mapping, (4, 4)) # 368
34+
2435
if device is None:
2536
device = "cuda:0" if torch.cuda.is_available() else "cpu"
2637
self.device = device
38+
2739
self.mel_extractor = MelSpectrogram(
2840
is_half=is_half,
2941
n_mel_channels=128,
30-
sampling_rate=16000,
42+
sampling_rate=sampling_rate,
3143
win_length=1024,
32-
hop_length=160,
33-
mel_fmin=30,
34-
mel_fmax=8000,
44+
hop_length=hop_length,
45+
mel_fmin=f0_min,
46+
mel_fmax=f0_max,
3547
device=device,
3648
).to(device)
49+
3750
if "privateuseone" in str(device):
3851
import onnxruntime as ort
3952

40-
ort_session = ort.InferenceSession(
53+
self.model = ort.InferenceSession(
4154
"%s/rmvpe.onnx" % os.environ["rmvpe_root"],
4255
providers=["DmlExecutionProvider"],
4356
)
44-
self.model = ort_session
4557
else:
46-
if str(self.device) == "cuda":
47-
self.device = torch.device("cuda:0")
48-
4958
def get_jit_model():
5059
jit_model_path = model_path.rstrip(".pth")
5160
jit_model_path += ".half.jit" if is_half else ".jit"
@@ -83,74 +92,40 @@ def get_default_model():
8392

8493
if use_jit:
8594
if is_half and "cpu" in str(self.device):
86-
logger.warning(
87-
"Use default rmvpe model. \
88-
Jit is not supported on the CPU for half floating point"
89-
)
9095
self.model = get_default_model()
9196
else:
9297
self.model = get_jit_model()
9398
else:
9499
self.model = get_default_model()
95100

96101
self.model = self.model.to(device)
97-
cents_mapping = 20 * np.arange(360) + 1997.3794084376191
98-
self.cents_mapping = np.pad(cents_mapping, (4, 4)) # 368
99-
100-
def mel2hidden(self, mel):
101-
with torch.no_grad():
102-
n_frames = mel.shape[-1]
103-
n_pad = 32 * ((n_frames - 1) // 32 + 1) - n_frames
104-
if n_pad > 0:
105-
mel = F.pad(mel, (0, n_pad), mode="constant")
106-
if "privateuseone" in str(self.device):
107-
onnx_input_name = self.model.get_inputs()[0].name
108-
onnx_outputs_names = self.model.get_outputs()[0].name
109-
hidden = self.model.run(
110-
[onnx_outputs_names],
111-
input_feed={onnx_input_name: mel.cpu().numpy()},
112-
)[0]
113-
else:
114-
mel = mel.half() if self.is_half else mel.float()
115-
hidden = self.model(mel)
116-
return hidden[:, :n_frames]
117102

118-
def decode(self, hidden, thred=0.03):
119-
cents_pred = self.to_local_average_cents(hidden, threshold=thred)
120-
f0 = 10 * (2 ** (cents_pred / 1200))
121-
f0[f0 == 10] = 0
122-
# f0 = np.array([10 * (2 ** (cent_pred / 1200)) if cent_pred else 0 for cent_pred in cents_pred])
123-
return f0
124-
125-
def infer_from_audio(self, audio, threshold=0.03):
126-
# torch.cuda.synchronize()
127-
# t0 = ttime()
128-
if not torch.is_tensor(audio):
129-
audio = torch.from_numpy(audio)
103+
def compute_f0(
104+
self,
105+
wav: np.ndarray[Any, np.dtype],
106+
p_len: Optional[int] = None,
107+
filter_radius: Optional[Union[int, float]] = None,
108+
):
109+
if p_len is None:
110+
p_len = wav.shape[0] // self.hop_length
111+
if not torch.is_tensor(wav):
112+
wav = torch.from_numpy(wav)
130113
mel = self.mel_extractor(
131-
audio.float().to(self.device).unsqueeze(0), center=True
114+
wav.float().to(self.device).unsqueeze(0), center=True
132115
)
133-
# print(123123123,mel.device.type)
134-
# torch.cuda.synchronize()
135-
# t1 = ttime()
136-
hidden = self.mel2hidden(mel)
137-
# torch.cuda.synchronize()
138-
# t2 = ttime()
139-
# print(234234,hidden.device.type)
116+
hidden = self._mel2hidden(mel)
140117
if "privateuseone" not in str(self.device):
141118
hidden = hidden.squeeze(0).cpu().numpy()
142119
else:
143120
hidden = hidden[0]
144121
if self.is_half == True:
145122
hidden = hidden.astype("float32")
146123

147-
f0 = self.decode(hidden, thred=threshold)
148-
# torch.cuda.synchronize()
149-
# t3 = ttime()
150-
# print("hmvpe:%s\t%s\t%s\t%s"%(t1-t0,t2-t1,t3-t2,t3-t0))
151-
return f0
124+
f0 = self._decode(hidden, thred=filter_radius)
152125

153-
def to_local_average_cents(self, salience, threshold=0.05):
126+
return self.interpolate_f0(self.resize_f0(f0, p_len))[0]
127+
128+
def _to_local_average_cents(self, salience, threshold=0.05):
154129
center = np.argmax(salience, axis=1) # 帧长#index
155130
salience = np.pad(salience, ((0, 0), (4, 4))) # 帧长,368
156131
center += 4
@@ -169,3 +144,28 @@ def to_local_average_cents(self, salience, threshold=0.05):
169144
maxx = np.max(salience, axis=1) # 帧长
170145
devided[maxx <= threshold] = 0
171146
return devided
147+
148+
def _mel2hidden(self, mel):
149+
with torch.no_grad():
150+
n_frames = mel.shape[-1]
151+
n_pad = 32 * ((n_frames - 1) // 32 + 1) - n_frames
152+
if n_pad > 0:
153+
mel = F.pad(mel, (0, n_pad), mode="constant")
154+
if "privateuseone" in str(self.device):
155+
onnx_input_name = self.model.get_inputs()[0].name
156+
onnx_outputs_names = self.model.get_outputs()[0].name
157+
hidden = self.model.run(
158+
[onnx_outputs_names],
159+
input_feed={onnx_input_name: mel.cpu().numpy()},
160+
)[0]
161+
else:
162+
mel = mel.half() if self.is_half else mel.float()
163+
hidden = self.model(mel)
164+
return hidden[:, :n_frames]
165+
166+
def _decode(self, hidden, thred=0.03):
167+
cents_pred = self._to_local_average_cents(hidden, threshold=thred)
168+
f0 = 10 * (2 ** (cents_pred / 1200))
169+
f0[f0 == 10] = 0
170+
# f0 = np.array([10 * (2 ** (cent_pred / 1200)) if cent_pred else 0 for cent_pred in cents_pred])
171+
return f0

0 commit comments

Comments
 (0)