Skip to content

Commit 54f7ae0

Browse files
committed
optimize(jit): move hubert & synthesizer into rvc
1 parent 0efe48c commit 54f7ae0

File tree

15 files changed

+30
-27
lines changed

15 files changed

+30
-27
lines changed

infer/lib/jit/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1 @@
11
from .utils import load, rmvpe_jit_export, synthesizer_jit_export
2-
from .synthesizer import get_synthesizer, get_synthesizer_ckpt

infer/lib/jit/utils.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,18 +44,18 @@ def to_jit_model(
4444
):
4545
model = None
4646
if model_type.lower() == "synthesizer":
47-
from .synthesizer import get_synthesizer
47+
from rvc.synthesizer import load_synthesizer
4848

49-
model, _ = get_synthesizer(model_path, device)
49+
model, _ = load_synthesizer(model_path, device)
5050
model.forward = model.infer
5151
elif model_type.lower() == "rmvpe":
5252
from .rmvpe import get_rmvpe
5353

5454
model = get_rmvpe(model_path, device)
5555
elif model_type.lower() == "hubert":
56-
from .hubert import get_hubert_model
56+
from rvc.hubert import get_hubert
5757

58-
model = get_hubert_model(model_path, device)
58+
model = get_hubert(model_path, device)
5959
model.forward = model.infer
6060
else:
6161
raise ValueError(f"No model type named {model_type}")
@@ -147,9 +147,9 @@ def synthesizer_jit_export(
147147
save_path += ".half.jit" if is_half else ".jit"
148148
if "cuda" in str(device) and ":" not in str(device):
149149
device = torch.device("cuda:0")
150-
from .synthesizer import get_synthesizer
150+
from rvc.synthesizer import load_synthesizer
151151

152-
model, cpt = get_synthesizer(model_path, device)
152+
model, cpt = load_synthesizer(model_path, device)
153153
assert isinstance(cpt, dict)
154154
model.forward = model.infer
155155
inputs = None

infer/lib/rmvpe.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -518,16 +518,15 @@ def __init__(self, model_path: str, is_half, device=None, use_jit=False):
518518
def get_jit_model():
519519
jit_model_path = model_path.rstrip(".pth")
520520
jit_model_path += ".half.jit" if is_half else ".jit"
521-
reload = False
521+
ckpt = None
522522
if os.path.exists(jit_model_path):
523523
ckpt = jit.load(jit_model_path)
524524
model_device = ckpt["device"]
525525
if model_device != str(self.device):
526-
reload = True
527-
else:
528-
reload = True
526+
del ckpt
527+
ckpt = None
529528

530-
if reload:
529+
if ckpt is None:
531530
ckpt = jit.rmvpe_jit_export(
532531
model_path=model_path,
533532
mode="script",
@@ -536,6 +535,7 @@ def get_jit_model():
536535
device=device,
537536
is_half=is_half,
538537
)
538+
539539
model = torch.jit.load(BytesIO(ckpt["model"]), map_location=device)
540540
return model
541541

infer/lib/rtrvc.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
import torchcrepe
1717
from torchaudio.transforms import Resample
1818

19+
from rvc.synthesizer import load_synthesizer
20+
1921
now_dir = os.getcwd()
2022
sys.path.append(now_dir)
2123
from multiprocessing import Manager as M
@@ -113,7 +115,7 @@ def forward_dml(ctx, x, scale):
113115
self.net_g: nn.Module = None
114116

115117
def set_default_model():
116-
self.net_g, cpt = jit.get_synthesizer(self.pth_path, self.device)
118+
self.net_g, cpt = load_synthesizer(self.pth_path, self.device)
117119
self.tgt_sr = cpt["config"][-1]
118120
cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0]
119121
self.if_f0 = cpt.get("f0", 1)

infer/modules/vc/hash.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from pybase16384 import encode_to_string, decode_from_string
77

88
from configs import CPUConfig, singleton_variable
9-
from infer.lib.jit import get_synthesizer_ckpt
9+
from rvc.synthesizer import get_synthesizer
1010

1111
from .pipeline import Pipeline
1212
from .utils import load_hubert
@@ -132,7 +132,7 @@ def model_hash_ckpt(cpt):
132132
config = CPUConfig()
133133

134134
with TorchSeedContext(114514):
135-
net_g, cpt = get_synthesizer_ckpt(cpt, config.device)
135+
net_g, cpt = get_synthesizer(cpt, config.device)
136136
tgt_sr = cpt["config"][-1]
137137
if_f0 = cpt.get("f0", 1)
138138
version = cpt.get("version", "v1")

infer/modules/vc/modules.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from io import BytesIO
1111

1212
from infer.lib.audio import load_audio, wav2
13-
from infer.lib.jit import get_synthesizer_ckpt, get_synthesizer
13+
from rvc.synthesizer import get_synthesizer, load_synthesizer
1414
from .info import show_model_info
1515
from .pipeline import Pipeline
1616
from .utils import get_index_path_from_model, load_hubert
@@ -62,7 +62,7 @@ def get_vc(self, sid, *to_return_protect):
6262
elif torch.backends.mps.is_available():
6363
torch.mps.empty_cache()
6464
###楼下不这么折腾清理不干净
65-
self.net_g, self.cpt = get_synthesizer_ckpt(
65+
self.net_g, self.cpt = get_synthesizer(
6666
self.cpt, self.config.device
6767
)
6868
self.if_f0 = self.cpt.get("f0", 1)
@@ -88,7 +88,7 @@ def get_vc(self, sid, *to_return_protect):
8888
person = f'{os.getenv("weight_root")}/{sid}'
8989
logger.info(f"Loading: {person}")
9090

91-
self.net_g, self.cpt = get_synthesizer(person, self.config.device)
91+
self.net_g, self.cpt = load_synthesizer(person, self.config.device)
9292
self.tgt_sr = self.cpt["config"][-1]
9393
self.cpt["config"][-3] = self.cpt["weight"]["emb_g.weight"].shape[0] # n_spk
9494
self.if_f0 = self.cpt.get("f0", 1)

rvc/f0/__init__.py

Whitespace-only changes.

infer/lib/jit/hubert.py renamed to rvc/hubert.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
import math
22
import random
33
from typing import Optional, Tuple
4+
45
from fairseq.checkpoint_utils import load_model_ensemble_and_task
6+
from fairseq.utils import index_put
57
import numpy as np
68
import torch
79
import torch.nn.functional as F
810

9-
# from fairseq.data.data_utils import compute_mask_indices
10-
from fairseq.utils import index_put
11-
1211

1312
# @torch.jit.script
1413
def pad_to_multiple(x, multiple, dim=-1, value=0):
@@ -263,7 +262,7 @@ def apply_mask(self, x, padding_mask, target_list):
263262
return x, mask_indices
264263

265264

266-
def get_hubert_model(
265+
def get_hubert(
267266
model_path="assets/hubert/hubert_base.pt", device=torch.device("cpu")
268267
):
269268
models, _, _ = load_model_ensemble_and_task(
File renamed without changes.
File renamed without changes.

0 commit comments

Comments
 (0)