|
1 | 1 | import asyncio
|
| 2 | +import contextlib |
2 | 3 | import copy
|
3 | 4 | import os
|
4 | 5 | import pickle
|
|
11 | 12 | import numpy as np
|
12 | 13 | from async_substrate_interface.errors import SubstrateRequestException
|
13 | 14 | from numpy.typing import NDArray
|
| 15 | +from packaging import version |
14 | 16 |
|
15 | 17 | from bittensor.core import settings
|
16 | 18 | from bittensor.core.chain_data import (
|
@@ -143,6 +145,27 @@ def latest_block_path(dir_path: str) -> str:
|
143 | 145 | return latest_file_full_path
|
144 | 146 |
|
145 | 147 |
|
| 148 | +def safe_globals(): |
| 149 | + """ |
| 150 | + Context manager to load torch files for version 2.6+ |
| 151 | + """ |
| 152 | + if version.parse(torch.__version__).release < version.parse("2.6").release: |
| 153 | + return contextlib.nullcontext() |
| 154 | + |
| 155 | + np_core = ( |
| 156 | + np._core if version.parse(np.__version__) >= version.parse("2.0.0") else np.core |
| 157 | + ) |
| 158 | + allow_list = [ |
| 159 | + np_core.multiarray._reconstruct, |
| 160 | + np.ndarray, |
| 161 | + np.dtype, |
| 162 | + type(np.dtype(np.uint32)), |
| 163 | + np.dtypes.Float32DType, |
| 164 | + bytes, |
| 165 | + ] |
| 166 | + return torch.serialization.safe_globals(allow_list) |
| 167 | + |
| 168 | + |
146 | 169 | class MetagraphMixin(ABC):
|
147 | 170 | """
|
148 | 171 | The metagraph class is a core component of the Bittensor network, representing the neural graph that forms the
|
@@ -1124,7 +1147,8 @@ def load_from_path(self, dir_path: str) -> "MetagraphMixin":
|
1124 | 1147 | """
|
1125 | 1148 |
|
1126 | 1149 | graph_file = latest_block_path(dir_path)
|
1127 |
| - state_dict = torch.load(graph_file) |
| 1150 | + with safe_globals(): |
| 1151 | + state_dict = torch.load(graph_file) |
1128 | 1152 | self.n = torch.nn.Parameter(state_dict["n"], requires_grad=False)
|
1129 | 1153 | self.block = torch.nn.Parameter(state_dict["block"], requires_grad=False)
|
1130 | 1154 | self.uids = torch.nn.Parameter(state_dict["uids"], requires_grad=False)
|
@@ -1256,7 +1280,8 @@ def load_from_path(self, dir_path: str) -> "MetagraphMixin":
|
1256 | 1280 | try:
|
1257 | 1281 | import torch as real_torch
|
1258 | 1282 |
|
1259 |
| - state_dict = real_torch.load(graph_filename) |
| 1283 | + with safe_globals(): |
| 1284 | + state_dict = real_torch.load(graph_filename) |
1260 | 1285 | for key in METAGRAPH_STATE_DICT_NDARRAY_KEYS:
|
1261 | 1286 | state_dict[key] = state_dict[key].detach().numpy()
|
1262 | 1287 | del real_torch
|
|
0 commit comments