Skip to content

Commit ef9c8eb

Browse files
authored
fix: Add weight whitelist support for torch 2.6 (#110)
1 parent e1aeb16 commit ef9c8eb

File tree

7 files changed

+45
-21
lines changed

7 files changed

+45
-21
lines changed

configs/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from .config import singleton_variable, Config, CPUConfig
1+
from .config import Singleton, Config, CPUConfig

configs/config.py

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,18 +22,16 @@
2222
]
2323

2424

25-
def singleton_variable(func):
26-
def wrapper(*args, **kwargs):
27-
if wrapper.instance is None:
28-
wrapper.instance = func(*args, **kwargs)
29-
return wrapper.instance
25+
class Singleton(type):
    """Metaclass that turns any class using it into a process-wide singleton.

    The first instantiation of a class declared with ``metaclass=Singleton``
    runs the normal construction path and caches the resulting instance;
    every later call returns that same cached instance.  NOTE(review): later
    constructor arguments are ignored once the instance exists — this matches
    the original behavior and appears intentional for config objects.
    """

    # Shared registry mapping each class object to its single instance.
    _instances = {}

    def __call__(cls, *args, **kwargs):
        if cls not in cls._instances:
            # Python 3 zero-argument super() is equivalent to the legacy
            # super(Singleton, cls) spelling and less error-prone.
            cls._instances[cls] = super().__call__(*args, **kwargs)
        return cls._instances[cls]
3332

3433

35-
@singleton_variable
36-
class Config:
34+
class Config(metaclass=Singleton):
3735
def __init__(self):
3836
self.device = "cuda:0"
3937
self.is_half = True
@@ -129,6 +127,16 @@ def has_xpu() -> bool:
129127
else:
130128
return False
131129

130+
@staticmethod
131+
def use_insecure_load():
132+
try:
133+
from fairseq.data.dictionary import Dictionary
134+
135+
torch.serialization.add_safe_globals([Dictionary])
136+
logging.warning("Using insecure weight loading for fairseq dictionary")
137+
except AttributeError:
138+
pass
139+
132140
def use_fp32_config(self):
133141
for config_file in version_config_list:
134142
self.json_config[config_file]["train"]["fp16_run"] = False
@@ -210,15 +218,20 @@ def device_config(self):
210218
else:
211219
if self.instead:
212220
logger.info(f"Use {self.instead} instead")
221+
213222
logger.info(
214223
"Half-precision floating-point: %s, device: %s"
215224
% (self.is_half, self.device)
216225
)
226+
227+
# Check if the pytorch is 2.6 or higher
228+
if tuple(map(int, torch.__version__.split("+")[0].split("."))) >= (2, 6, 0):
229+
self.use_insecure_load()
230+
217231
return x_pad, x_query, x_center, x_max
218232

219233

220-
@singleton_variable
221-
class CPUConfig:
234+
class CPUConfig(metaclass=Singleton):
222235
def __init__(self):
223236
self.device = "cpu"
224237
self.is_half = False

i18n/i18n.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import json
22
import locale
33
import os
4-
from configs import singleton_variable
4+
from configs import Singleton
55

66

77
def load_language_list(language):
@@ -10,8 +10,7 @@ def load_language_list(language):
1010
return language_list
1111

1212

13-
@singleton_variable
14-
class I18nAuto:
13+
class I18nAuto(metaclass=Singleton):
1514
def __init__(self, language=None):
1615
if language in ["Auto", None]:
1716
language = locale.getdefaultlocale(

infer/modules/train/extract_f0_print.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,10 @@ def go(self, paths, f0_method):
9696
# exp_dir=r"E:\codes\py39\dataset\mi-test"
9797
# n_p=16
9898
# f = open("%s/log_extract_f0.log"%exp_dir, "w")
99+
100+
from configs import Config
101+
Config.use_insecure_load()
102+
99103
printt(" ".join(sys.argv))
100104
featureInput = FeatureInput(is_half, device)
101105
paths = []

infer/modules/train/extract_feature_print.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,15 @@
2323
os.environ["CUDA_VISIBLE_DEVICES"] = str(i_gpu)
2424
version = sys.argv[6]
2525
is_half = sys.argv[7].lower() == "true"
26+
2627
import fairseq
2728
import numpy as np
2829
import torch
2930
import torch.nn.functional as F
3031

32+
from configs import Config
33+
Config.use_insecure_load()
34+
3135
if "privateuseone" not in device:
3236
device = "cpu"
3337
if torch.cuda.is_available():

infer/modules/train/preprocess.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,4 +142,6 @@ def preprocess_trainset(inp_root, sr, n_p, exp_dir, per):
142142

143143

144144
if __name__ == "__main__":
145+
from configs import Config
146+
Config.use_insecure_load()
145147
preprocess_trainset(inp_root, sr, n_p, exp_dir, per)

infer/modules/vc/hash.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22
import torch
33
import hashlib
44
import pathlib
5+
6+
from functools import lru_cache
57
from scipy.fft import fft
68
from pybase16384 import encode_to_string, decode_from_string
79

8-
from configs import CPUConfig, singleton_variable
10+
from configs import CPUConfig
911
from rvc.synthesizer import get_synthesizer
1012

1113
from .pipeline import Pipeline
@@ -29,27 +31,27 @@ def __exit__(self, type, value, traceback):
2931
expand_factor = 65536 * 8
3032

3133

32-
@singleton_variable
34+
@lru_cache(None)  # unbounded cache: the archive is loaded once per process
def original_audio_storage():
    """Load the bundled ``lgdsng.npz`` archive sitting next to this module."""
    archive_path = pathlib.Path(__file__).with_name("lgdsng.npz")
    return np.load(archive_path)
3537

3638

37-
@singleton_variable
39+
@lru_cache(None)
def original_audio():
    """Return the array stored under key ``"a"`` of the bundled archive
    (presumably the reference audio samples — inferred from the name)."""
    storage = original_audio_storage()
    return storage["a"]
4042

4143

42-
@singleton_variable
44+
@lru_cache(None)
def original_audio_time_minus():
    """Return the array stored under key ``"t"`` of the bundled archive
    (time-domain reference data, per the function name — unverified)."""
    storage = original_audio_storage()
    return storage["t"]
4547

4648

47-
@singleton_variable
49+
@lru_cache(None)
def original_audio_freq_minus():
    """Return the array stored under key ``"f"`` of the bundled archive
    (frequency-domain reference data, per the function name — unverified)."""
    storage = original_audio_storage()
    return storage["f"]
5052

5153

52-
@singleton_variable
54+
@lru_cache(None)
def original_rmvpe_f0():
    """Return the ``("pitch", "pitchf")`` pair from the bundled archive as a
    2-tuple of arrays (reference rmvpe f0 data, per the function name)."""
    storage = original_audio_storage()
    return storage["pitch"], storage["pitchf"]

0 commit comments

Comments (0)