4 files changed: +33 −2 lines changed

Changed file 1 of 4:
@@ -37,6 +37,7 @@
 from lightning.fabric.strategies.registry import _StrategyRegistry
 from lightning.fabric.strategies.strategy import _Sharded
 from lightning.fabric.utilities.distributed import log
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_6
 from lightning.fabric.utilities.load import _move_state_into
 from lightning.fabric.utilities.rank_zero import rank_zero_info, rank_zero_warn
 from lightning.fabric.utilities.seed import reset_seed
@@ -47,6 +48,7 @@
 from torch.optim.lr_scheduler import _LRScheduler

 _DEEPSPEED_AVAILABLE = RequirementCache("deepspeed")
+_DEEPSPEED_GREATER_EQUAL_0_16 = RequirementCache("deepspeed>=0.16.0")


 # TODO(fabric): Links in the docstrings to PL-specific deepspeed user docs need to be replaced.
@@ -239,6 +241,19 @@ def __init__(
                 " Install it by running `pip install -U deepspeed`."
             )

+        if _TORCH_GREATER_EQUAL_2_6 and not _DEEPSPEED_GREATER_EQUAL_0_16:
+            # Starting with PyTorch 2.6, `torch.load` defaults to `weights_only=True` when loading full checkpoints.
+            # DeepSpeed added support for this behavior in version 0.16.0.
+            import deepspeed
+
+            deepspeed_version = deepspeed.__version__
+
+            raise ImportError(
+                f"PyTorch >= 2.6 requires DeepSpeed >= 0.16.0. "
+                f"Detected DeepSpeed version: {deepspeed_version}. "
+                "Please upgrade by running `pip install -U 'deepspeed>=0.16.0'`."
+            )
+
         super().__init__(
             accelerator=accelerator,
             parallel_devices=parallel_devices,
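The guard added above can be reproduced outside of Lightning's helpers. Below is a minimal, standalone sketch of the same version gate using importlib.metadata and packaging instead of RequirementCache / compare_version; the function name and control flow are illustrative only and are not part of this diff.

# Standalone sketch of the version gate added above (illustrative, not the PR's code).
from importlib.metadata import PackageNotFoundError, version

from packaging.version import Version


def _check_deepspeed_compatible_with_torch() -> None:
    torch_ge_2_6 = Version(version("torch")) >= Version("2.6.0")
    try:
        deepspeed_ge_0_16 = Version(version("deepspeed")) >= Version("0.16.0")
    except PackageNotFoundError:
        # DeepSpeed is not installed at all; the earlier "pip install -U deepspeed" error covers that case.
        return
    if torch_ge_2_6 and not deepspeed_ge_0_16:
        # PyTorch 2.6 flips torch.load to weights_only=True by default; DeepSpeed handles this from 0.16.0 on.
        raise ImportError(
            "PyTorch >= 2.6 requires DeepSpeed >= 0.16.0. "
            f"Detected DeepSpeed version: {version('deepspeed')}. "
            "Please upgrade by running `pip install -U 'deepspeed>=0.16.0'`."
        )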
Changed file 2 of 4:

@@ -36,5 +36,6 @@
 _TORCH_GREATER_EQUAL_2_4_1 = compare_version("torch", operator.ge, "2.4.1")
 _TORCH_GREATER_EQUAL_2_5 = compare_version("torch", operator.ge, "2.5.0")
 _TORCH_LESS_EQUAL_2_6 = compare_version("torch", operator.le, "2.6.0")
+_TORCH_GREATER_EQUAL_2_6 = compare_version("torch", operator.ge, "2.6.0")
 _TORCHMETRICS_GREATER_EQUAL_1_0_0 = compare_version("torchmetrics", operator.ge, "1.0.0")
 _PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10)
Changed file 3 of 4:

@@ -35,10 +35,12 @@
 from lightning.fabric.strategies import _StrategyRegistry
 from lightning.fabric.strategies.deepspeed import (
     _DEEPSPEED_AVAILABLE,
+    _DEEPSPEED_GREATER_EQUAL_0_16,
     _format_precision_config,
     _validate_checkpoint_directory,
     _validate_device_index_selection,
 )
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_6
 from lightning.fabric.utilities.optimizer import _optimizers_to_device
 from lightning.fabric.utilities.seed import reset_seed
 from lightning.fabric.utilities.types import _PATH
@@ -262,6 +264,19 @@ def __init__(
                 " Install it by running `pip install -U deepspeed`."
             )

+        if _TORCH_GREATER_EQUAL_2_6 and not _DEEPSPEED_GREATER_EQUAL_0_16:
+            # Starting with PyTorch 2.6, `torch.load` defaults to `weights_only=True` when loading full checkpoints.
+            # DeepSpeed added support for this behavior in version 0.16.0.
+            import deepspeed
+
+            deepspeed_version = deepspeed.__version__
+
+            raise ImportError(
+                f"PyTorch >= 2.6 requires DeepSpeed >= 0.16.0. "
+                f"Detected DeepSpeed version: {deepspeed_version}. "
+                "Please upgrade by running `pip install -U 'deepspeed>=0.16.0'`."
+            )
+
         super().__init__(
             accelerator=accelerator,
             parallel_devices=parallel_devices,
Changed file 4 of 4:

@@ -93,10 +93,10 @@ def convert_zero_checkpoint_to_fp32_state_dict(
     ]
     checkpoint_dir = ds_checkpoint_dir(checkpoint_dir)
     optim_files = get_optim_files(checkpoint_dir)
-    optim_state = torch.load(optim_files[0], map_location=CPU_DEVICE)
+    optim_state = torch.load(optim_files[0], map_location=CPU_DEVICE, weights_only=False)
     zero_stage = optim_state["optimizer_state_dict"]["zero_stage"]
     model_file = get_model_state_file(checkpoint_dir, zero_stage)
-    client_state = torch.load(model_file, map_location=CPU_DEVICE)
+    client_state = torch.load(model_file, map_location=CPU_DEVICE, weights_only=False)
     client_state = {key: value for key, value in client_state.items() if key not in deepspeed_states}
     # State dict keys will include reference to wrapper _LightningModuleWrapperBase in old checkpoints created in
     # Lightning version < 2.1. Delete the `_forward_module` prefix before saving.
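For context on the two weights_only=False additions above: DeepSpeed's optimizer and model-state files store pickled Python objects beyond plain tensors, which the PyTorch 2.6 default of weights_only=True refuses to deserialize. The sketch below reproduces that behavior in isolation; the file name and the stored argparse.Namespace are made up for illustration, and on torch < 2.6 the first load simply succeeds.

# Illustrative only: demonstrates the torch.load default this diff works around.
import argparse

import torch

ckpt = {"step": 10, "extra_state": argparse.Namespace(lr=1e-3)}  # non-tensor payload, similar to DeepSpeed client state
torch.save(ckpt, "demo.ckpt")

try:
    # On torch >= 2.6 this uses weights_only=True by default and rejects the Namespace object.
    torch.load("demo.ckpt")
except Exception as err:
    print(f"weights_only=True rejected the checkpoint: {err}")

# Opting out restores the old unpickling behavior; only do this for checkpoints from a trusted source.
restored = torch.load("demo.ckpt", map_location="cpu", weights_only=False)
print(restored["extra_state"])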