Commit c51a73f

optimize(infer): move jit into rvc
1 parent: e936e24

8 files changed: 145 additions, 231 deletions


infer/lib/jit/__init__.py

Lines changed: 0 additions & 1 deletion
This file was deleted.

infer/lib/jit/utils.py

Lines changed: 0 additions & 163 deletions
This file was deleted.

infer/lib/rtrvc.py

Lines changed: 4 additions & 21 deletions
@@ -125,27 +125,10 @@ def set_default_model():
                 self.net_g = self.net_g.float()
 
         def set_jit_model():
-            jit_pth_path = self.pth_path.rstrip(".pth")
-            jit_pth_path += ".half.jit" if self.is_half else ".jit"
-            reload = False
-            if str(self.device) == "cuda":
-                self.device = torch.device("cuda:0")
-            if os.path.exists(jit_pth_path):
-                cpt = jit.load(jit_pth_path)
-                model_device = cpt["device"]
-                if model_device != str(self.device):
-                    reload = True
-            else:
-                reload = True
-
-            if reload:
-                cpt = jit.synthesizer_jit_export(
-                    self.pth_path,
-                    "script",
-                    None,
-                    device=self.device,
-                    is_half=self.is_half,
-                )
+            from rvc.jit import get_jit_model
+            from rvc.synthesizer import synthesizer_jit_export
+
+            cpt = get_jit_model(self.pth_path, self.is_half, synthesizer_jit_export)
 
             self.tgt_sr = cpt["config"][-1]
             self.if_f0 = cpt.get("f0", 1)
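
For reference, a minimal sketch of the new call path outside of rtrvc.py, assuming the rvc package from this commit is importable and a synthesizer checkpoint exists at the hypothetical path below. The hunk above passes the path, the half flag and the exporter; get_jit_model as defined in rvc/jit/jit.py below also takes a device argument, which this sketch passes explicitly. It also assumes synthesizer_jit_export follows the same exporter signature as rmvpe_jit_export and fills in the "config" and "f0" fields that rtrvc.py reads afterwards.

import torch
from rvc.jit import get_jit_model
from rvc.synthesizer import synthesizer_jit_export

pth_path = "assets/weights/example.pth"  # hypothetical checkpoint path
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Loads example.jit / example.half.jit if it exists and was exported for this
# device; otherwise re-exports the synthesizer via the callback and caches it.
cpt = get_jit_model(pth_path, False, device, synthesizer_jit_export)

tgt_sr = cpt["config"][-1]  # assumed to be present, mirroring what rtrvc.py reads
if_f0 = cpt.get("f0", 1)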
rvc/f0/models.py

Lines changed: 3 additions & 2 deletions
@@ -1,12 +1,13 @@
 import torch
 
-
-def get_rmvpe(model_path="assets/rmvpe/rmvpe.pt", device=torch.device("cpu")):
+def get_rmvpe(model_path="assets/rmvpe/rmvpe.pt", device=torch.device("cpu"), is_half=False):
     from rvc.f0.e2e import E2E
 
     model = E2E(4, 1, (2, 2))
     ckpt = torch.load(model_path, map_location=device)
     model.load_state_dict(ckpt)
     model.eval()
+    if is_half:
+        model = model.half()
     model = model.to(device)
     return model
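
A minimal usage sketch of the extended signature, assuming this file is rvc/f0/models.py (as the relative import in rvc/f0/rmvpe.py below suggests) and that the default checkpoint exists at assets/rmvpe/rmvpe.pt:

import torch
from rvc.f0.models import get_rmvpe

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# is_half=True halves the weights before the final .to(device); keep it False on CPU.
model = get_rmvpe("assets/rmvpe/rmvpe.pt", device=device, is_half=(device.type == "cuda"))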

rvc/f0/rmvpe.py

Lines changed: 32 additions & 44 deletions
@@ -6,12 +6,35 @@
 import torch
 import torch.nn.functional as F
 
-from infer.lib import jit
+from rvc.jit import load_inputs, get_jit_model, export_jit_model, save_pickle
 
 from .mel import MelSpectrogram
-from .e2e import E2E
 from .f0 import F0Predictor
-
+from .models import get_rmvpe
+
+
+def rmvpe_jit_export(
+    model_path: str,
+    mode: str = "script",
+    inputs_path: str = None,
+    save_path: str = None,
+    device=torch.device("cpu"),
+    is_half=False,
+):
+    if not save_path:
+        save_path = model_path.rstrip(".pth")
+        save_path += ".half.jit" if is_half else ".jit"
+    if "cuda" in str(device) and ":" not in str(device):
+        device = torch.device("cuda:0")
+
+    model = get_rmvpe(model_path, device, is_half)
+    inputs = None
+    if mode == "trace":
+        inputs = load_inputs(inputs_path, device, is_half)
+    ckpt = export_jit_model(model, mode, inputs, device, is_half)
+    ckpt["device"] = str(device)
+    save_pickle(ckpt, save_path)
+    return ckpt
 
 class RMVPE(F0Predictor):
     def __init__(
@@ -57,51 +80,16 @@ def __init__(
                 providers=["DmlExecutionProvider"],
             )
         else:
-
-            def get_jit_model():
-                jit_model_path = model_path.rstrip(".pth")
-                jit_model_path += ".half.jit" if is_half else ".jit"
-                ckpt = None
-                if os.path.exists(jit_model_path):
-                    ckpt = jit.load(jit_model_path)
-                    model_device = ckpt["device"]
-                    if model_device != str(self.device):
-                        del ckpt
-                        ckpt = None
-
-                if ckpt is None:
-                    ckpt = jit.rmvpe_jit_export(
-                        model_path=model_path,
-                        mode="script",
-                        inputs_path=None,
-                        save_path=jit_model_path,
-                        device=self.device,
-                        is_half=is_half,
-                    )
-
+            def rmvpe_jit_model():
+                ckpt = get_jit_model(model_path, is_half, self.device, rmvpe_jit_export)
                 model = torch.jit.load(BytesIO(ckpt["model"]), map_location=self.device)
+                model = model.to(self.device)
                 return model
 
-            def get_default_model():
-                model = E2E(4, 1, (2, 2))
-                ckpt = torch.load(model_path, map_location="cpu")
-                model.load_state_dict(ckpt)
-                model.eval()
-                if is_half:
-                    model = model.half()
-                else:
-                    model = model.float()
-                return model
-
-            if use_jit:
-                if is_half and "cpu" in str(self.device):
-                    self.model = get_default_model()
-                else:
-                    self.model = get_jit_model()
+            if use_jit and not (is_half and "cpu" in str(self.device)):
+                self.model = rmvpe_jit_model()
             else:
-                self.model = get_default_model()
-
-            self.model = self.model.to(self.device)
+                self.model = get_rmvpe(model_path, self.device, is_half)
 
     def compute_f0(
         self,
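
The new module-level rmvpe_jit_export can also be called ahead of time to pre-build the TorchScript cache that RMVPE.__init__ picks up. A sketch, assuming the checkpoint sits at the repo default assets/rmvpe/rmvpe.pt; with the default save_path it writes assets/rmvpe/rmvpe.jit (or .half.jit) and returns the same pickled dict that get_jit_model would otherwise load:

from io import BytesIO

import torch
from rvc.f0.rmvpe import rmvpe_jit_export

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
is_half = device.type == "cuda"

ckpt = rmvpe_jit_export("assets/rmvpe/rmvpe.pt", mode="script", device=device, is_half=is_half)
# The TorchScript module is stored as serialized bytes under "model".
model = torch.jit.load(BytesIO(ckpt["model"]), map_location=device)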

rvc/jit/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+from .jit import load_inputs, get_jit_model, export_jit_model, save_pickle

rvc/jit/jit.py

Lines changed: 76 additions & 0 deletions
@@ -0,0 +1,76 @@
+import pickle
+from io import BytesIO
+from collections import OrderedDict
+import os
+
+import torch
+
+
+def load_pickle(path: str):
+    with open(path, "rb") as f:
+        return pickle.load(f)
+
+
+def save_pickle(ckpt: dict, save_path: str):
+    with open(save_path, "wb") as f:
+        pickle.dump(ckpt, f)
+
+def load_inputs(path: torch.serialization.FILE_LIKE, device: str, is_half=False):
+    parm = torch.load(path, map_location=torch.device("cpu"))
+    for key in parm.keys():
+        parm[key] = parm[key].to(device)
+        if is_half and parm[key].dtype == torch.float32:
+            parm[key] = parm[key].half()
+        elif not is_half and parm[key].dtype == torch.float16:
+            parm[key] = parm[key].float()
+    return parm
+
+def export_jit_model(
+    model: torch.nn.Module,
+    mode: str = "trace",
+    inputs: dict = None,
+    device=torch.device("cpu"),
+    is_half: bool = False,
+) -> dict:
+    model = model.half() if is_half else model.float()
+    model.eval()
+    if mode == "trace":
+        assert inputs is not None
+        model_jit = torch.jit.trace(model, example_kwarg_inputs=inputs)
+    elif mode == "script":
+        model_jit = torch.jit.script(model)
+    model_jit.to(device)
+    model_jit = model_jit.half() if is_half else model_jit.float()
+    buffer = BytesIO()
+    # model_jit=model_jit.cpu()
+    torch.jit.save(model_jit, buffer)
+    del model_jit
+    cpt = OrderedDict()
+    cpt["model"] = buffer.getvalue()
+    cpt["is_half"] = is_half
+    return cpt
+
+
+def get_jit_model(model_path: str, is_half: bool, device: str, exporter):
+    jit_model_path = model_path.rstrip(".pth")
+    jit_model_path += ".half.jit" if is_half else ".jit"
+    ckpt = None
+
+    if os.path.exists(jit_model_path):
+        ckpt = load_pickle(jit_model_path)
+        model_device = ckpt["device"]
+        if model_device != str(device):
+            del ckpt
+            ckpt = None
+
+    if ckpt is None:
+        ckpt = exporter(
+            model_path=model_path,
+            mode="script",
+            inputs_path=None,
+            save_path=jit_model_path,
+            device=device,
+            is_half=is_half,
+        )
+
+    return ckpt
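
Taken together, get_jit_model implements a simple caching contract: the exporter is any callable that accepts the keyword arguments model_path, mode, inputs_path, save_path, device and is_half and returns a dict recording its "device", as rmvpe_jit_export does above (and as synthesizer_jit_export presumably does in rvc/synthesizer). A sketch of the round trip with a hypothetical toy exporter and toy.pth path:

from io import BytesIO

import torch
from rvc.jit import export_jit_model, get_jit_model, save_pickle


def toy_exporter(model_path, mode, inputs_path, save_path, device, is_half):
    # Stand-in for a real loader: a real exporter would build the model from model_path.
    model = torch.nn.Linear(4, 4)
    ckpt = export_jit_model(model, mode, None, device, is_half)
    ckpt["device"] = str(device)
    save_pickle(ckpt, save_path)  # cached as <model_path minus .pth>.jit for later calls
    return ckpt


device = torch.device("cpu")
ckpt = get_jit_model("toy.pth", False, device, toy_exporter)  # exports on the first call
model = torch.jit.load(BytesIO(ckpt["model"]), map_location=device)
print(model(torch.zeros(1, 4)).shape)  # torch.Size([1, 4])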
