Skip to content

Commit 1d56cc5

Browse files
committed
fix: Add weight whitelist support in torch 2.6
1 parent e1aeb16 commit 1d56cc5

File tree

1 file changed

+16
-0
lines changed

1 file changed

+16
-0
lines changed

configs/config.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,16 @@ def has_xpu() -> bool:
129129
else:
130130
return False
131131

132+
@staticmethod
133+
def use_insecure_load():
134+
try:
135+
from fairseq.data.dictionary import Dictionary
136+
137+
logging.warning("Using insecure weight loading for fairseq dictionary")
138+
torch.serialization.add_safe_globals([Dictionary])
139+
except AttributeError:
140+
pass
141+
132142
def use_fp32_config(self):
133143
for config_file in version_config_list:
134144
self.json_config[config_file]["train"]["fp16_run"] = False
@@ -210,10 +220,16 @@ def device_config(self):
210220
else:
211221
if self.instead:
212222
logger.info(f"Use {self.instead} instead")
223+
213224
logger.info(
214225
"Half-precision floating-point: %s, device: %s"
215226
% (self.is_half, self.device)
216227
)
228+
229+
# Check if the pytorch is 2.6 or higher
230+
if tuple(map(int, torch.__version__.split("+")[0].split("."))) >= (2, 6, 0):
231+
self.use_insecure_load()
232+
217233
return x_pad, x_query, x_center, x_max
218234

219235

0 commit comments

Comments
 (0)