diff --git a/requirements/fabric/strategies.txt b/requirements/fabric/strategies.txt index 7856db1df2eec..fd6e3063712c4 100644 --- a/requirements/fabric/strategies.txt +++ b/requirements/fabric/strategies.txt @@ -5,5 +5,5 @@ # note: is a bug around 0.10 with `MPS_Accelerator must implement all abstract methods` # shall be resolved by https://github.com/microsoft/DeepSpeed/issues/4372 -deepspeed >=0.14.1,<=0.15.0; platform_system != "Windows" and platform_system != "Darwin" # strict +deepspeed >=0.15.0,<0.17.0; platform_system != "Windows" and platform_system != "Darwin" # strict bitsandbytes >=0.45.2,<0.47.0; platform_system != "Darwin" diff --git a/requirements/pytorch/strategies.txt b/requirements/pytorch/strategies.txt index 89392d6006d38..03a0a9dd2d947 100644 --- a/requirements/pytorch/strategies.txt +++ b/requirements/pytorch/strategies.txt @@ -3,4 +3,4 @@ # note: is a bug around 0.10 with `MPS_Accelerator must implement all abstract methods` # shall be resolved by https://github.com/microsoft/DeepSpeed/issues/4372 -deepspeed >=0.14.1,<=0.15.0; platform_system != "Windows" and platform_system != "Darwin" # strict +deepspeed >=0.15.0,<0.17.0; platform_system != "Windows" and platform_system != "Darwin" # strict diff --git a/src/lightning/fabric/strategies/deepspeed.py b/src/lightning/fabric/strategies/deepspeed.py index fe72db20e2b85..e5a7b1fd29ad5 100644 --- a/src/lightning/fabric/strategies/deepspeed.py +++ b/src/lightning/fabric/strategies/deepspeed.py @@ -37,6 +37,7 @@ from lightning.fabric.strategies.registry import _StrategyRegistry from lightning.fabric.strategies.strategy import _Sharded from lightning.fabric.utilities.distributed import log +from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_6 from lightning.fabric.utilities.load import _move_state_into from lightning.fabric.utilities.rank_zero import rank_zero_info, rank_zero_warn from lightning.fabric.utilities.seed import reset_seed @@ -47,6 +48,7 @@ from torch.optim.lr_scheduler 
import _LRScheduler _DEEPSPEED_AVAILABLE = RequirementCache("deepspeed") +_DEEPSPEED_GREATER_EQUAL_0_16 = RequirementCache("deepspeed>=0.16.0") # TODO(fabric): Links in the docstrings to PL-specific deepspeed user docs need to be replaced. @@ -239,6 +241,19 @@ def __init__( " Install it by running `pip install -U deepspeed`." ) + if _TORCH_GREATER_EQUAL_2_6 and not _DEEPSPEED_GREATER_EQUAL_0_16: + # Starting with PyTorch 2.6, `torch.load` defaults to `weights_only=True` when loading full checkpoints. + # DeepSpeed added support for this behavior in version 0.16.0. + import deepspeed + + deepspeed_version = deepspeed.__version__ + + raise ImportError( + "PyTorch >= 2.6 requires DeepSpeed >= 0.16.0. " + f"Detected DeepSpeed version: {deepspeed_version}. " + "Please upgrade by running `pip install -U 'deepspeed>=0.16.0'`." + ) + super().__init__( accelerator=accelerator, parallel_devices=parallel_devices, diff --git a/src/lightning/fabric/utilities/imports.py b/src/lightning/fabric/utilities/imports.py index 70239baac0e6d..44765aae5a620 100644 --- a/src/lightning/fabric/utilities/imports.py +++ b/src/lightning/fabric/utilities/imports.py @@ -36,5 +36,6 @@ _TORCH_GREATER_EQUAL_2_4_1 = compare_version("torch", operator.ge, "2.4.1") _TORCH_GREATER_EQUAL_2_5 = compare_version("torch", operator.ge, "2.5.0") _TORCH_LESS_EQUAL_2_6 = compare_version("torch", operator.le, "2.6.0") +_TORCH_GREATER_EQUAL_2_6 = compare_version("torch", operator.ge, "2.6.0") _TORCHMETRICS_GREATER_EQUAL_1_0_0 = compare_version("torchmetrics", operator.ge, "1.0.0") _PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10) diff --git a/src/lightning/pytorch/strategies/deepspeed.py b/src/lightning/pytorch/strategies/deepspeed.py index c5253f77cdedb..50fdc25968bc9 100644 --- a/src/lightning/pytorch/strategies/deepspeed.py +++ b/src/lightning/pytorch/strategies/deepspeed.py @@ -35,10 +35,12 @@ from lightning.fabric.strategies import _StrategyRegistry from 
lightning.fabric.strategies.deepspeed import ( _DEEPSPEED_AVAILABLE, + _DEEPSPEED_GREATER_EQUAL_0_16, _format_precision_config, _validate_checkpoint_directory, _validate_device_index_selection, ) +from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_6 from lightning.fabric.utilities.optimizer import _optimizers_to_device from lightning.fabric.utilities.seed import reset_seed from lightning.fabric.utilities.types import _PATH @@ -262,6 +264,19 @@ def __init__( " Install it by running `pip install -U deepspeed`." ) + if _TORCH_GREATER_EQUAL_2_6 and not _DEEPSPEED_GREATER_EQUAL_0_16: + # Starting with PyTorch 2.6, `torch.load` defaults to `weights_only=True` when loading full checkpoints. + # DeepSpeed added support for this behavior in version 0.16.0. + import deepspeed + + deepspeed_version = deepspeed.__version__ + + raise ImportError( + "PyTorch >= 2.6 requires DeepSpeed >= 0.16.0. " + f"Detected DeepSpeed version: {deepspeed_version}. " + "Please upgrade by running `pip install -U 'deepspeed>=0.16.0'`." 
+ ) + super().__init__( accelerator=accelerator, parallel_devices=parallel_devices, diff --git a/src/lightning/pytorch/utilities/deepspeed.py b/src/lightning/pytorch/utilities/deepspeed.py index 619e22cac9401..20b418437c681 100644 --- a/src/lightning/pytorch/utilities/deepspeed.py +++ b/src/lightning/pytorch/utilities/deepspeed.py @@ -93,10 +93,10 @@ def convert_zero_checkpoint_to_fp32_state_dict( ] checkpoint_dir = ds_checkpoint_dir(checkpoint_dir) optim_files = get_optim_files(checkpoint_dir) - optim_state = torch.load(optim_files[0], map_location=CPU_DEVICE) + optim_state = torch.load(optim_files[0], map_location=CPU_DEVICE, weights_only=False) zero_stage = optim_state["optimizer_state_dict"]["zero_stage"] model_file = get_model_state_file(checkpoint_dir, zero_stage) - client_state = torch.load(model_file, map_location=CPU_DEVICE) + client_state = torch.load(model_file, map_location=CPU_DEVICE, weights_only=False) client_state = {key: value for key, value in client_state.items() if key not in deepspeed_states} # State dict keys will include reference to wrapper _LightningModuleWrapperBase in old checkpoints created in # Lightning version < 2.1. Delete the `_forward_module` prefix before saving.