Skip to content

Commit 77b371d

Browse files
committed
optimize(f0): move some f0s into rvc.f0
1 parent d44a942 commit 77b371d

File tree

15 files changed

+91
-185
lines changed

15 files changed

+91
-185
lines changed

infer/lib/jit/rmvpe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33

44
def get_rmvpe(model_path="assets/rmvpe/rmvpe.pt", device=torch.device("cpu")):
5-
from infer.lib.rmvpe import E2E
5+
from rvc.f0.e2e import E2E
66

77
model = E2E(4, 1, (2, 2))
88
ckpt = torch.load(model_path, map_location=device)

infer/lib/rmvpe.py

Lines changed: 5 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,6 @@
66

77
from infer.lib import jit
88

9-
try:
10-
# Fix "Torch not compiled with CUDA enabled"
11-
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
12-
13-
if torch.xpu.is_available():
14-
from infer.modules.ipex import ipex_init
15-
16-
ipex_init()
17-
except Exception: # pylint: disable=broad-exception-caught
18-
pass
19-
import torch.nn as nn
209
import torch.nn.functional as F
2110

2211
import logging
@@ -127,13 +116,13 @@ def mel2hidden(self, mel):
127116
return hidden[:, :n_frames]
128117

129118
def decode(self, hidden, thred=0.03):
130-
cents_pred = self.to_local_average_cents(hidden, thred=thred)
119+
cents_pred = self.to_local_average_cents(hidden, threshold=thred)
131120
f0 = 10 * (2 ** (cents_pred / 1200))
132121
f0[f0 == 10] = 0
133122
# f0 = np.array([10 * (2 ** (cent_pred / 1200)) if cent_pred else 0 for cent_pred in cents_pred])
134123
return f0
135124

136-
def infer_from_audio(self, audio, thred=0.03):
125+
def infer_from_audio(self, audio, threshold=0.03):
137126
# torch.cuda.synchronize()
138127
# t0 = ttime()
139128
if not torch.is_tensor(audio):
@@ -155,17 +144,15 @@ def infer_from_audio(self, audio, thred=0.03):
155144
if self.is_half == True:
156145
hidden = hidden.astype("float32")
157146

158-
f0 = self.decode(hidden, thred=thred)
147+
f0 = self.decode(hidden, thred=threshold)
159148
# torch.cuda.synchronize()
160149
# t3 = ttime()
161150
# print("hmvpe:%s\t%s\t%s\t%s"%(t1-t0,t2-t1,t3-t2,t3-t0))
162151
return f0
163152

164-
def to_local_average_cents(self, salience, thred=0.05):
165-
# t0 = ttime()
153+
def to_local_average_cents(self, salience, threshold=0.05):
166154
center = np.argmax(salience, axis=1) # shape (n_frames,); values are indices
167155
salience = np.pad(salience, ((0, 0), (4, 4))) # shape (n_frames, 368)
168-
# t1 = ttime()
169156
center += 4
170157
todo_salience = []
171158
todo_cents_mapping = []
@@ -174,15 +161,11 @@ def to_local_average_cents(self, salience, thred=0.05):
174161
for idx in range(salience.shape[0]):
175162
todo_salience.append(salience[:, starts[idx] : ends[idx]][idx])
176163
todo_cents_mapping.append(self.cents_mapping[starts[idx] : ends[idx]])
177-
# t2 = ttime()
178164
todo_salience = np.array(todo_salience) # shape (n_frames, 9)
179165
todo_cents_mapping = np.array(todo_cents_mapping) # shape (n_frames, 9)
180166
product_sum = np.sum(todo_salience * todo_cents_mapping, 1)
181167
weight_sum = np.sum(todo_salience, 1) # shape (n_frames,)
182168
devided = product_sum / weight_sum # shape (n_frames,)
183-
# t3 = ttime()
184169
maxx = np.max(salience, axis=1) # shape (n_frames,)
185-
devided[maxx <= thred] = 0
186-
# t4 = ttime()
187-
# print("decode:%s\t%s\t%s\t%s" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
170+
devided[maxx <= threshold] = 0
188171
return devided

infer/modules/train/extract/extract_f0_print.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def compute_f0(self, path, f0_method):
8989
self.model_rmvpe = RMVPE(
9090
"assets/rmvpe/rmvpe.pt", is_half=False, device="cpu"
9191
)
92-
f0 = self.model_rmvpe.infer_from_audio(x, thred=0.03)
92+
f0 = self.model_rmvpe.infer_from_audio(x, threshold=0.03)
9393
return f0
9494

9595
def coarse_f0(self, f0):

infer/modules/train/extract/extract_f0_rmvpe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def compute_f0(self, path, f0_method):
5252
self.model_rmvpe = RMVPE(
5353
"assets/rmvpe/rmvpe.pt", is_half=is_half, device="cuda"
5454
)
55-
f0 = self.model_rmvpe.infer_from_audio(x, thred=0.03)
55+
f0 = self.model_rmvpe.infer_from_audio(x, threshold=0.03)
5656
return f0
5757

5858
def coarse_f0(self, f0):

infer/modules/train/extract/extract_f0_rmvpe_dml.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def compute_f0(self, path, f0_method):
5050
self.model_rmvpe = RMVPE(
5151
"assets/rmvpe/rmvpe.pt", is_half=False, device=device
5252
)
53-
f0 = self.model_rmvpe.infer_from_audio(x, thred=0.03)
53+
f0 = self.model_rmvpe.infer_from_audio(x, threshold=0.03)
5454
return f0
5555

5656
def coarse_f0(self, f0):

infer/modules/train/train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
from torch.utils.data import DataLoader
4848
from torch.utils.tensorboard import SummaryWriter
4949

50-
from rvc import utils
50+
from rvc.layers import utils
5151
from infer.lib.train.data_utils import (
5252
DistributedBucketSampler,
5353
TextAudioCollate,

infer/modules/vc/pipeline.py

Lines changed: 9 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -5,40 +5,24 @@
55

66
logger = logging.getLogger(__name__)
77

8-
from functools import lru_cache
98
from time import time
109

1110
import faiss
1211
import librosa
1312
import numpy as np
14-
import parselmouth
1513
import pyworld
1614
import torch
1715
import torch.nn.functional as F
1816
import torchcrepe
1917
from scipy import signal
2018

19+
from rvc.f0 import PM, Harvest
20+
2121
now_dir = os.getcwd()
2222
sys.path.append(now_dir)
2323

2424
bh, ah = signal.butter(N=5, Wn=48, btype="high", fs=16000)
2525

26-
input_audio_path2wav = {}
27-
28-
29-
@lru_cache
30-
def cache_harvest_f0(f0_cache_key, fs, f0max, f0min, frame_period):
31-
audio = input_audio_path2wav[f0_cache_key]
32-
f0, t = pyworld.harvest(
33-
audio,
34-
fs=fs,
35-
f0_ceil=f0max,
36-
f0_floor=f0min,
37-
frame_period=frame_period,
38-
)
39-
f0 = pyworld.stonemask(audio, f0, t, fs)
40-
return f0
41-
4226

4327
def change_rms(data1, sr1, data2, sr2, rate): # 1 is the input audio, 2 is the output audio, rate is the proportion of 2
4428
# print(data1.max(),data2.max())
@@ -90,37 +74,18 @@ def get_f0(
9074
filter_radius,
9175
inp_f0=None,
9276
):
93-
global input_audio_path2wav
94-
time_step = self.window / self.sr * 1000
9577
f0_min = 50
9678
f0_max = 1100
9779
f0_mel_min = 1127 * np.log(1 + f0_min / 700)
9880
f0_mel_max = 1127 * np.log(1 + f0_max / 700)
9981
if f0_method == "pm":
100-
f0 = (
101-
parselmouth.Sound(x, self.sr)
102-
.to_pitch_ac(
103-
time_step=time_step / 1000,
104-
voicing_threshold=0.6,
105-
pitch_floor=f0_min,
106-
pitch_ceiling=f0_max,
107-
)
108-
.selected_array["frequency"]
109-
)
110-
pad_size = (p_len - len(f0) + 1) // 2
111-
if pad_size > 0 or p_len - len(f0) - pad_size > 0:
112-
f0 = np.pad(
113-
f0, [[pad_size, p_len - len(f0) - pad_size]], mode="constant"
114-
)
82+
if not hasattr(self, "pm"):
83+
self.pm = PM(self.window, f0_min, f0_max, self.sr)
84+
f0 = self.pm.compute_f0(x, p_len=p_len)
11585
elif f0_method == "harvest":
116-
from hashlib import md5
117-
118-
f0_cache_key = md5(x.tobytes()).digest()
119-
input_audio_path2wav[f0_cache_key] = x.astype(np.double)
120-
f0 = cache_harvest_f0(f0_cache_key, self.sr, f0_max, f0_min, 10)
121-
del input_audio_path2wav[f0_cache_key]
122-
if filter_radius > 2:
123-
f0 = signal.medfilt(f0, 3)
86+
if not hasattr(self, "harvest"):
87+
self.harvest = Harvest(self.window, f0_min, f0_max, self.sr)
88+
f0 = self.harvest.compute_f0(x, p_len=p_len, filter_radius=filter_radius)
12489
elif f0_method == "crepe":
12590
model = "full"
12691
# Pick a batch size that doesn't cause memory errors on your gpu
@@ -155,7 +120,7 @@ def get_f0(
155120
device=self.device,
156121
# use_jit=self.config.use_jit,
157122
)
158-
f0 = self.model_rmvpe.infer_from_audio(x, thred=0.03)
123+
f0 = self.model_rmvpe.infer_from_audio(x, threshold=0.03)
159124

160125
if "privateuseone" in str(self.device): # clean ortruntime memory
161126
del self.model_rmvpe.model

rvc/f0/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from .dio import Dio
2+
from .harvest import Harvest
3+
from .pm import PM
4+
from .f0 import F0Predictor

rvc/onnx/f0/dio.py renamed to rvc/f0/dio.py

Lines changed: 7 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,27 +6,15 @@
66
from .f0 import F0Predictor
77

88

9-
class DioF0Predictor(F0Predictor):
9+
class Dio(F0Predictor):
1010
def __init__(self, hop_length=512, f0_min=50, f0_max=1100, sampling_rate=44100):
1111
super().__init__(hop_length, f0_min, f0_max, sampling_rate)
1212

13-
def compute_f0(self, wav: np.ndarray[Any, np.dtype], p_len: Optional[int] = None):
14-
if p_len is None:
15-
p_len = wav.shape[0] // self.hop_length
16-
f0, t = pyworld.dio(
17-
wav.astype(np.double),
18-
fs=self.sampling_rate,
19-
f0_floor=self.f0_min,
20-
f0_ceil=self.f0_max,
21-
frame_period=1000 * self.hop_length / self.sampling_rate,
22-
)
23-
f0 = pyworld.stonemask(wav.astype(np.double), f0, t, self.sampling_rate)
24-
for index, pitch in enumerate(f0):
25-
f0[index] = round(pitch, 1)
26-
return self.interpolate_f0(self.resize_f0(f0, p_len))[0]
27-
28-
def compute_f0_uv(
29-
self, wav: np.ndarray[Any, np.dtype], p_len: Optional[int] = None
13+
def compute_f0(
14+
self,
15+
wav: np.ndarray[Any, np.dtype],
16+
p_len: Optional[int] = None,
17+
filter_radius: Optional[int] = None,
3018
):
3119
if p_len is None:
3220
p_len = wav.shape[0] // self.hop_length
@@ -40,4 +28,4 @@ def compute_f0_uv(
4028
f0 = pyworld.stonemask(wav.astype(np.double), f0, t, self.sampling_rate)
4129
for index, pitch in enumerate(f0):
4230
f0[index] = round(pitch, 1)
43-
return self.interpolate_f0(self.resize_f0(f0, p_len))
31+
return self.interpolate_f0(self.resize_f0(f0, p_len))[0]

rvc/onnx/f0/f0.py renamed to rvc/f0/f0.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,10 @@ def __init__(self, hop_length=512, f0_min=50, f0_max=1100, sampling_rate=44100):
1111
self.sampling_rate = sampling_rate
1212

1313
def compute_f0(
14-
self, wav: np.ndarray[Any, np.dtype], p_len: Optional[int] = None
15-
): ...
16-
17-
def compute_f0_uv(
18-
self, wav: np.ndarray[Any, np.dtype], p_len: Optional[int] = None
14+
self,
15+
wav: np.ndarray[Any, np.dtype],
16+
p_len: Optional[int] = None,
17+
filter_radius: Optional[int] = None,
1918
): ...
2019

2120
def interpolate_f0(self, f0: np.ndarray[Any, np.dtype]):

0 commit comments

Comments
 (0)