|
22 | 22 | ] |
23 | 23 |
|
24 | 24 |
|
25 | | -def singleton_variable(func): |
26 | | - def wrapper(*args, **kwargs): |
27 | | - if wrapper.instance is None: |
28 | | - wrapper.instance = func(*args, **kwargs) |
29 | | - return wrapper.instance |
| 25 | +class Singleton(type): |
| 26 | + _instances = {} |
30 | 27 |
|
31 | | - wrapper.instance = None |
32 | | - return wrapper |
| 28 | + def __call__(cls, *args, **kwargs): |
| 29 | + if cls not in cls._instances: |
| 30 | + cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) |
| 31 | + return cls._instances[cls] |
33 | 32 |
|
34 | 33 |
|
35 | | -@singleton_variable |
36 | | -class Config: |
| 34 | +class Config(metaclass=Singleton): |
37 | 35 | def __init__(self): |
38 | 36 | self.device = "cuda:0" |
39 | 37 | self.is_half = True |
@@ -129,6 +127,16 @@ def has_xpu() -> bool: |
129 | 127 | else: |
130 | 128 | return False |
131 | 129 |
|
| 130 | + @staticmethod |
| 131 | + def use_insecure_load(): |
| 132 | + try: |
| 133 | + from fairseq.data.dictionary import Dictionary |
| 134 | + |
| 135 | + torch.serialization.add_safe_globals([Dictionary]) |
| 136 | + logging.warning("Using insecure weight loading for fairseq dictionary") |
| 137 | + except AttributeError: |
| 138 | + pass |
| 139 | + |
132 | 140 | def use_fp32_config(self): |
133 | 141 | for config_file in version_config_list: |
134 | 142 | self.json_config[config_file]["train"]["fp16_run"] = False |
@@ -210,15 +218,20 @@ def device_config(self): |
210 | 218 | else: |
211 | 219 | if self.instead: |
212 | 220 | logger.info(f"Use {self.instead} instead") |
| 221 | + |
213 | 222 | logger.info( |
214 | 223 | "Half-precision floating-point: %s, device: %s" |
215 | 224 | % (self.is_half, self.device) |
216 | 225 | ) |
| 226 | + |
| 227 | + # Check if the pytorch is 2.6 or higher |
| 228 | + if tuple(map(int, torch.__version__.split("+")[0].split("."))) >= (2, 6, 0): |
| 229 | + self.use_insecure_load() |
| 230 | + |
217 | 231 | return x_pad, x_query, x_center, x_max |
218 | 232 |
|
219 | 233 |
|
220 | | -@singleton_variable |
221 | | -class CPUConfig: |
| 234 | +class CPUConfig(metaclass=Singleton): |
222 | 235 | def __init__(self): |
223 | 236 | self.device = "cpu" |
224 | 237 | self.is_half = False |
|
0 commit comments