
Commit aa95a4d

feat: enforce DeepSpeed version requirement for PyTorch >= 2.6 and update checkpoint loading behavior

1 parent: e26cd46
4 files changed: +33 additions, −2 deletions

src/lightning/fabric/strategies/deepspeed.py (15 additions, 0 deletions)

@@ -37,6 +37,7 @@
 from lightning.fabric.strategies.registry import _StrategyRegistry
 from lightning.fabric.strategies.strategy import _Sharded
 from lightning.fabric.utilities.distributed import log
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_6
 from lightning.fabric.utilities.load import _move_state_into
 from lightning.fabric.utilities.rank_zero import rank_zero_info, rank_zero_warn
 from lightning.fabric.utilities.seed import reset_seed
@@ -47,6 +48,7 @@
     from torch.optim.lr_scheduler import _LRScheduler

 _DEEPSPEED_AVAILABLE = RequirementCache("deepspeed")
+_DEEPSPEED_GREATER_EQUAL_0_16 = RequirementCache("deepspeed>=0.16.0")


 # TODO(fabric): Links in the docstrings to PL-specific deepspeed user docs need to be replaced.
@@ -239,6 +241,19 @@ def __init__(
                 " Install it by running `pip install -U deepspeed`."
             )

+        if _TORCH_GREATER_EQUAL_2_6 and not _DEEPSPEED_GREATER_EQUAL_0_16:
+            # Starting with PyTorch 2.6, `torch.load` defaults to `weights_only=True` when loading full checkpoints.
+            # DeepSpeed added support for this behavior in version 0.16.0.
+            import deepspeed
+
+            deepspeed_version = deepspeed.__version__
+
+            raise ImportError(
+                f"PyTorch >= 2.6 requires DeepSpeed >= 0.16.0. "
+                f"Detected DeepSpeed version: {deepspeed_version}. "
+                "Please upgrade by running `pip install -U 'deepspeed>=0.16.0'`."
+            )
+
         super().__init__(
             accelerator=accelerator,
             parallel_devices=parallel_devices,
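
Note on the guard: `RequirementCache` comes from `lightning_utilities` and checks whether the installed distribution satisfies a pip-style requirement string, caching the result; its truthiness reflects that check. A minimal sketch of the behavior this diff relies on; the printed strings are illustrative, and the exact message formatting is an assumption rather than a quote from the library:

```python
from lightning_utilities.core.imports import RequirementCache

# Truthiness triggers (and caches) a check of the installed distribution
# against the pip-style requirement string.
_DEEPSPEED_GREATER_EQUAL_0_16 = RequirementCache("deepspeed>=0.16.0")

if _DEEPSPEED_GREATER_EQUAL_0_16:
    print("deepspeed>=0.16.0 is satisfied")
else:
    # str() on the cache is expected to yield a human-readable explanation.
    print(f"requirement not met: {_DEEPSPEED_GREATER_EQUAL_0_16}")
```

This is also why the new branch can `import deepspeed` unconditionally: the `_DEEPSPEED_AVAILABLE` check a few lines earlier has already guaranteed the package is installed, so only its version is in question here.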

src/lightning/fabric/utilities/imports.py (1 addition, 0 deletions)

@@ -36,5 +36,6 @@
 _TORCH_GREATER_EQUAL_2_4_1 = compare_version("torch", operator.ge, "2.4.1")
 _TORCH_GREATER_EQUAL_2_5 = compare_version("torch", operator.ge, "2.5.0")
 _TORCH_LESS_EQUAL_2_6 = compare_version("torch", operator.le, "2.6.0")
+_TORCH_GREATER_EQUAL_2_6 = compare_version("torch", operator.ge, "2.6.0")
 _TORCHMETRICS_GREATER_EQUAL_1_0_0 = compare_version("torchmetrics", operator.ge, "1.0.0")
 _PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10)
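
These flags are plain booleans computed once at import time by `compare_version`, which pairs a package name with an `operator` predicate and a version string. A small sketch, assuming the `lightning_utilities.core.imports.compare_version` helper that this module imports:

```python
import operator

from lightning_utilities.core.imports import compare_version

# Compares the installed torch version against "2.6.0" using >=.
_TORCH_GREATER_EQUAL_2_6 = compare_version("torch", operator.ge, "2.6.0")

print(_TORCH_GREATER_EQUAL_2_6)  # True on torch 2.6.0 and newer, else False
```

Note that the existing `_TORCH_LESS_EQUAL_2_6` and the new `_TORCH_GREATER_EQUAL_2_6` are not mutually exclusive: both are True on exactly torch 2.6.0.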

src/lightning/pytorch/strategies/deepspeed.py (15 additions, 0 deletions)

@@ -35,10 +35,12 @@
 from lightning.fabric.strategies import _StrategyRegistry
 from lightning.fabric.strategies.deepspeed import (
     _DEEPSPEED_AVAILABLE,
+    _DEEPSPEED_GREATER_EQUAL_0_16,
     _format_precision_config,
     _validate_checkpoint_directory,
     _validate_device_index_selection,
 )
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_6
 from lightning.fabric.utilities.optimizer import _optimizers_to_device
 from lightning.fabric.utilities.seed import reset_seed
 from lightning.fabric.utilities.types import _PATH
@@ -262,6 +264,19 @@ def __init__(
                 " Install it by running `pip install -U deepspeed`."
             )

+        if _TORCH_GREATER_EQUAL_2_6 and not _DEEPSPEED_GREATER_EQUAL_0_16:
+            # Starting with PyTorch 2.6, `torch.load` defaults to `weights_only=True` when loading full checkpoints.
+            # DeepSpeed added support for this behavior in version 0.16.0.
+            import deepspeed
+
+            deepspeed_version = deepspeed.__version__
+
+            raise ImportError(
+                f"PyTorch >= 2.6 requires DeepSpeed >= 0.16.0. "
+                f"Detected DeepSpeed version: {deepspeed_version}. "
+                "Please upgrade by running `pip install -U 'deepspeed>=0.16.0'`."
+            )
+
         super().__init__(
             accelerator=accelerator,
             parallel_devices=parallel_devices,
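
Because the check lives in the strategy's `__init__`, an incompatible environment now fails fast at strategy construction instead of deep inside DeepSpeed's engine during checkpoint loading. A hedged sketch of how that surfaces; the detected version shown in the comment is made up for illustration:

```python
# Illustrative only: assumes torch >= 2.6 with deepspeed < 0.16.0 installed.
from lightning.pytorch.strategies import DeepSpeedStrategy

try:
    strategy = DeepSpeedStrategy()
except ImportError as err:
    # e.g. "PyTorch >= 2.6 requires DeepSpeed >= 0.16.0.
    #       Detected DeepSpeed version: 0.15.4. Please upgrade ..."
    print(err)
```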

src/lightning/pytorch/utilities/deepspeed.py (2 additions, 2 deletions)

@@ -93,10 +93,10 @@ def convert_zero_checkpoint_to_fp32_state_dict(
     ]
     checkpoint_dir = ds_checkpoint_dir(checkpoint_dir)
     optim_files = get_optim_files(checkpoint_dir)
-    optim_state = torch.load(optim_files[0], map_location=CPU_DEVICE)
+    optim_state = torch.load(optim_files[0], map_location=CPU_DEVICE, weights_only=False)
     zero_stage = optim_state["optimizer_state_dict"]["zero_stage"]
     model_file = get_model_state_file(checkpoint_dir, zero_stage)
-    client_state = torch.load(model_file, map_location=CPU_DEVICE)
+    client_state = torch.load(model_file, map_location=CPU_DEVICE, weights_only=False)
     client_state = {key: value for key, value in client_state.items() if key not in deepspeed_states}
     # State dict keys will include reference to wrapper _LightningModuleWrapperBase in old checkpoints created in
     # Lightning version < 2.1. Delete the `_forward_module` prefix before saving.
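
These two `torch.load` calls read DeepSpeed's optimizer and model-state shards, which pickle arbitrary Python objects (optimizer metadata, client state) alongside tensors. Under PyTorch 2.6's new `weights_only=True` default those loads would fail to unpickle, so the explicit `weights_only=False` restores the previous behavior for these trusted, locally produced files. A self-contained sketch of the difference; the file name and class are made up for illustration:

```python
import torch

# A checkpoint mixing tensors with an arbitrary Python object, similar in
# spirit to the client state stored in DeepSpeed checkpoint shards.
class RunMetadata:
    def __init__(self, step: int) -> None:
        self.step = step

torch.save({"weight": torch.randn(2, 2), "meta": RunMetadata(step=7)}, "ckpt.pt")

try:
    # The PyTorch >= 2.6 default: arbitrary classes are rejected.
    torch.load("ckpt.pt", weights_only=True)
except Exception as err:
    print(f"weights_only=True refused the checkpoint: {err}")

# Appropriate only for checkpoints from a trusted source, since full
# unpickling can execute arbitrary code.
state = torch.load("ckpt.pt", map_location="cpu", weights_only=False)
print(state["meta"].step)  # 7
```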
