Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion requirements/fabric/strategies.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@

# note: there is a bug around 0.10 with `MPS_Accelerator must implement all abstract methods`
# shall be resolved by https://github.com/microsoft/DeepSpeed/issues/4372
deepspeed >=0.14.1,<=0.15.0; platform_system != "Windows" and platform_system != "Darwin" # strict
deepspeed >=0.15.0,<0.17.0; platform_system != "Windows" and platform_system != "Darwin" # strict
bitsandbytes >=0.45.2,<0.47.0; platform_system != "Darwin"
2 changes: 1 addition & 1 deletion requirements/pytorch/strategies.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@

# note: there is a bug around 0.10 with `MPS_Accelerator must implement all abstract methods`
# shall be resolved by https://github.com/microsoft/DeepSpeed/issues/4372
deepspeed >=0.14.1,<=0.15.0; platform_system != "Windows" and platform_system != "Darwin" # strict
deepspeed >=0.15.0,<0.17.0; platform_system != "Windows" and platform_system != "Darwin" # strict
15 changes: 15 additions & 0 deletions src/lightning/fabric/strategies/deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from lightning.fabric.strategies.registry import _StrategyRegistry
from lightning.fabric.strategies.strategy import _Sharded
from lightning.fabric.utilities.distributed import log
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_6
from lightning.fabric.utilities.load import _move_state_into
from lightning.fabric.utilities.rank_zero import rank_zero_info, rank_zero_warn
from lightning.fabric.utilities.seed import reset_seed
Expand All @@ -47,6 +48,7 @@
from torch.optim.lr_scheduler import _LRScheduler

_DEEPSPEED_AVAILABLE = RequirementCache("deepspeed")
_DEEPSPEED_GREATER_EQUAL_0_16 = RequirementCache("deepspeed>=0.16.0")


# TODO(fabric): Links in the docstrings to PL-specific deepspeed user docs need to be replaced.
Expand Down Expand Up @@ -239,6 +241,19 @@ def __init__(
" Install it by running `pip install -U deepspeed`."
)

if _TORCH_GREATER_EQUAL_2_6 and not _DEEPSPEED_GREATER_EQUAL_0_16:
# Starting with PyTorch 2.6, `torch.load` defaults to `weights_only=True` when loading full checkpoints.
# DeepSpeed added support for this behavior in version 0.16.0.
import deepspeed

deepspeed_version = deepspeed.__version__

raise ImportError(
f"PyTorch >= 2.6 requires DeepSpeed >= 0.16.0. "
f"Detected DeepSpeed version: {deepspeed_version}. "
"Please upgrade by running `pip install -U 'deepspeed>=0.16.0'`."
)

super().__init__(
accelerator=accelerator,
parallel_devices=parallel_devices,
Expand Down
1 change: 1 addition & 0 deletions src/lightning/fabric/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,5 +36,6 @@
_TORCH_GREATER_EQUAL_2_4_1 = compare_version("torch", operator.ge, "2.4.1")
_TORCH_GREATER_EQUAL_2_5 = compare_version("torch", operator.ge, "2.5.0")
_TORCH_LESS_EQUAL_2_6 = compare_version("torch", operator.le, "2.6.0")
_TORCH_GREATER_EQUAL_2_6 = compare_version("torch", operator.ge, "2.6.0")
_TORCHMETRICS_GREATER_EQUAL_1_0_0 = compare_version("torchmetrics", operator.ge, "1.0.0")
_PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10)
15 changes: 15 additions & 0 deletions src/lightning/pytorch/strategies/deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,12 @@
from lightning.fabric.strategies import _StrategyRegistry
from lightning.fabric.strategies.deepspeed import (
_DEEPSPEED_AVAILABLE,
_DEEPSPEED_GREATER_EQUAL_0_16,
_format_precision_config,
_validate_checkpoint_directory,
_validate_device_index_selection,
)
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_6
from lightning.fabric.utilities.optimizer import _optimizers_to_device
from lightning.fabric.utilities.seed import reset_seed
from lightning.fabric.utilities.types import _PATH
Expand Down Expand Up @@ -262,6 +264,19 @@ def __init__(
" Install it by running `pip install -U deepspeed`."
)

if _TORCH_GREATER_EQUAL_2_6 and not _DEEPSPEED_GREATER_EQUAL_0_16:
# Starting with PyTorch 2.6, `torch.load` defaults to `weights_only=True` when loading full checkpoints.
# DeepSpeed added support for this behavior in version 0.16.0.
import deepspeed

deepspeed_version = deepspeed.__version__

raise ImportError(
f"PyTorch >= 2.6 requires DeepSpeed >= 0.16.0. "
f"Detected DeepSpeed version: {deepspeed_version}. "
"Please upgrade by running `pip install -U 'deepspeed>=0.16.0'`."
)

super().__init__(
accelerator=accelerator,
parallel_devices=parallel_devices,
Expand Down
4 changes: 2 additions & 2 deletions src/lightning/pytorch/utilities/deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,10 @@ def convert_zero_checkpoint_to_fp32_state_dict(
]
checkpoint_dir = ds_checkpoint_dir(checkpoint_dir)
optim_files = get_optim_files(checkpoint_dir)
optim_state = torch.load(optim_files[0], map_location=CPU_DEVICE)
optim_state = torch.load(optim_files[0], map_location=CPU_DEVICE, weights_only=False)
zero_stage = optim_state["optimizer_state_dict"]["zero_stage"]
model_file = get_model_state_file(checkpoint_dir, zero_stage)
client_state = torch.load(model_file, map_location=CPU_DEVICE)
client_state = torch.load(model_file, map_location=CPU_DEVICE, weights_only=False)
client_state = {key: value for key, value in client_state.items() if key not in deepspeed_states}
# State dict keys will include reference to wrapper _LightningModuleWrapperBase in old checkpoints created in
# Lightning version < 2.1. Delete the `_forward_module` prefix before saving.
Expand Down
Loading