6 files changed: +35 / -4 lines
File 1 of 6: requirements pin

@@ -6,4 +6,4 @@
 # note: there is a bug around 0.10 with `MPS_Accelerator must implement all abstract methods`
 # should be resolved by https://github.com/microsoft/DeepSpeed/issues/4372
-deepspeed >=0.14.1,<=0.15.0; platform_system != "Windows" and platform_system != "Darwin"  # strict
+deepspeed >=0.15.0,<0.17.0; platform_system != "Windows" and platform_system != "Darwin"  # strict
 bitsandbytes >=0.45.2,<0.47.0; platform_system != "Darwin"
File 2 of 6: second requirements pin (same change)

@@ -4,3 +4,3 @@
 # note: there is a bug around 0.10 with `MPS_Accelerator must implement all abstract methods`
 # should be resolved by https://github.com/microsoft/DeepSpeed/issues/4372
-deepspeed >=0.14.1,<=0.15.0; platform_system != "Windows" and platform_system != "Darwin"  # strict
+deepspeed >=0.15.0,<0.17.0; platform_system != "Windows" and platform_system != "Darwin"  # strict
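Both files move the same pin: DeepSpeed 0.16.x becomes installable and everything below 0.15.0 is dropped. A minimal sketch of what the old and new specifiers admit, using the `packaging` library (the probed versions are illustrative):

```python
from packaging.specifiers import SpecifierSet
from packaging.version import Version

old_pin = SpecifierSet(">=0.14.1,<=0.15.0")
new_pin = SpecifierSet(">=0.15.0,<0.17.0")

for candidate in ("0.14.1", "0.15.0", "0.16.4", "0.17.0"):
    v = Version(candidate)
    # 0.15.0 is the only probed version both pins accept; 0.16.x is newly allowed.
    print(f"{candidate}: old={v in old_pin} new={v in new_pin}")
```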
File 3 of 6: `lightning.fabric.strategies.deepspeed`

@@ -37,6 +37,7 @@
 from lightning.fabric.strategies.registry import _StrategyRegistry
 from lightning.fabric.strategies.strategy import _Sharded
 from lightning.fabric.utilities.distributed import log
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_6
 from lightning.fabric.utilities.load import _move_state_into
 from lightning.fabric.utilities.rank_zero import rank_zero_info, rank_zero_warn
 from lightning.fabric.utilities.seed import reset_seed
@@ -47,6 +48,7 @@
     from torch.optim.lr_scheduler import _LRScheduler

 _DEEPSPEED_AVAILABLE = RequirementCache("deepspeed")
+_DEEPSPEED_GREATER_EQUAL_0_16 = RequirementCache("deepspeed>=0.16.0")


 # TODO(fabric): Links in the docstrings to PL-specific deepspeed user docs need to be replaced.
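A note on the new flag before the diff continues into `__init__`: `RequirementCache` (from `lightning_utilities`) checks its specifier against the installed distribution and then behaves like a plain boolean, so the guard below stays cheap. A minimal sketch, assuming only that `lightning_utilities` is installed:

```python
from lightning_utilities.core.imports import RequirementCache

# Truthy only if `deepspeed` is importable *and* satisfies the specifier;
# the result is computed once and cached on the instance.
_DEEPSPEED_GREATER_EQUAL_0_16 = RequirementCache("deepspeed>=0.16.0")

if not _DEEPSPEED_GREATER_EQUAL_0_16:
    print("this environment has no deepspeed>=0.16.0")
```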
@@ -239,6 +241,19 @@ def __init__(
                 " Install it by running `pip install -U deepspeed`."
             )

+        if _TORCH_GREATER_EQUAL_2_6 and not _DEEPSPEED_GREATER_EQUAL_0_16:
+            # Starting with PyTorch 2.6, `torch.load` defaults to `weights_only=True` when loading full checkpoints.
+            # DeepSpeed added support for this behavior in version 0.16.0.
+            import deepspeed
+
+            deepspeed_version = deepspeed.__version__
+
+            raise ImportError(
+                "Using DeepSpeed with PyTorch >= 2.6 requires DeepSpeed >= 0.16.0."
+                f" Detected DeepSpeed version: {deepspeed_version}."
+                " Please upgrade by running `pip install -U 'deepspeed>=0.16.0'`."
+            )
+
         super().__init__(
             accelerator=accelerator,
             parallel_devices=parallel_devices,
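The motivation for the guard: PyTorch 2.6 flipped the default of `torch.load` to `weights_only=True`, which refuses to unpickle arbitrary Python objects, and DeepSpeed checkpoints contain exactly such objects. A self-contained sketch of the failure mode (run under PyTorch >= 2.6; `ClientState` is just a stand-in for the non-tensor state DeepSpeed pickles):

```python
import torch

class ClientState:  # stand-in for non-tensor objects inside a checkpoint
    epoch = 3

torch.save({"weights": torch.randn(2), "client_state": ClientState()}, "ckpt.pt")

try:
    # PyTorch >= 2.6: defaults to weights_only=True and rejects ClientState.
    torch.load("ckpt.pt")
except Exception as err:
    print(f"default load failed: {err}")

# Opting out restores the pre-2.6 behavior; safe only for trusted files.
ckpt = torch.load("ckpt.pt", weights_only=False)
print(ckpt["client_state"].epoch)  # -> 3
```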
File 4 of 6: `lightning.fabric.utilities.imports`

@@ -36,5 +36,6 @@
 _TORCH_GREATER_EQUAL_2_4_1 = compare_version("torch", operator.ge, "2.4.1")
 _TORCH_GREATER_EQUAL_2_5 = compare_version("torch", operator.ge, "2.5.0")
 _TORCH_LESS_EQUAL_2_6 = compare_version("torch", operator.le, "2.6.0")
+_TORCH_GREATER_EQUAL_2_6 = compare_version("torch", operator.ge, "2.6.0")
 _TORCHMETRICS_GREATER_EQUAL_1_0_0 = compare_version("torchmetrics", operator.ge, "1.0.0")
 _PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10)
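`compare_version` resolves the installed package's version and applies the given operator to it. A rough equivalent built on `packaging` and `importlib.metadata`; `compare_version_sketch` is a hypothetical name, not Lightning's actual helper:

```python
import operator
from importlib.metadata import PackageNotFoundError, version

from packaging.version import Version

def compare_version_sketch(package: str, op, target: str) -> bool:
    """Hypothetical stand-in for `compare_version`; False if not installed."""
    try:
        return op(Version(version(package)), Version(target))
    except PackageNotFoundError:
        return False

print(compare_version_sketch("torch", operator.ge, "2.6.0"))
```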
File 5 of 6: the Trainer-side DeepSpeed strategy (mirrors file 3)

@@ -35,10 +35,12 @@
 from lightning.fabric.strategies import _StrategyRegistry
 from lightning.fabric.strategies.deepspeed import (
     _DEEPSPEED_AVAILABLE,
+    _DEEPSPEED_GREATER_EQUAL_0_16,
     _format_precision_config,
     _validate_checkpoint_directory,
     _validate_device_index_selection,
 )
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_6
 from lightning.fabric.utilities.optimizer import _optimizers_to_device
 from lightning.fabric.utilities.seed import reset_seed
 from lightning.fabric.utilities.types import _PATH
@@ -262,6 +264,19 @@ def __init__(
                 " Install it by running `pip install -U deepspeed`."
             )

+        if _TORCH_GREATER_EQUAL_2_6 and not _DEEPSPEED_GREATER_EQUAL_0_16:
+            # Starting with PyTorch 2.6, `torch.load` defaults to `weights_only=True` when loading full checkpoints.
+            # DeepSpeed added support for this behavior in version 0.16.0.
+            import deepspeed
+
+            deepspeed_version = deepspeed.__version__
+
+            raise ImportError(
+                "Using DeepSpeed with PyTorch >= 2.6 requires DeepSpeed >= 0.16.0."
+                f" Detected DeepSpeed version: {deepspeed_version}."
+                " Please upgrade by running `pip install -U 'deepspeed>=0.16.0'`."
+            )
+
         super().__init__(
             accelerator=accelerator,
             parallel_devices=parallel_devices,
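What a user hits in practice: with torch >= 2.6 and deepspeed < 0.16 installed, constructing the strategy now fails fast with the ImportError above instead of erroring later, deep inside checkpoint loading. A hedged usage sketch; the strategy alias and Trainer arguments shown are just one common configuration:

```python
from lightning.pytorch import Trainer

# Raises the new ImportError immediately at strategy construction when
# torch >= 2.6 is paired with deepspeed < 0.16.0.
trainer = Trainer(accelerator="cuda", devices=2, strategy="deepspeed_stage_2", precision="16-mixed")
```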
File 6 of 6: ZeRO checkpoint conversion helper

@@ -93,10 +93,10 @@ def convert_zero_checkpoint_to_fp32_state_dict(
     ]
     checkpoint_dir = ds_checkpoint_dir(checkpoint_dir)
     optim_files = get_optim_files(checkpoint_dir)
-    optim_state = torch.load(optim_files[0], map_location=CPU_DEVICE)
+    optim_state = torch.load(optim_files[0], map_location=CPU_DEVICE, weights_only=False)
     zero_stage = optim_state["optimizer_state_dict"]["zero_stage"]
     model_file = get_model_state_file(checkpoint_dir, zero_stage)
-    client_state = torch.load(model_file, map_location=CPU_DEVICE)
+    client_state = torch.load(model_file, map_location=CPU_DEVICE, weights_only=False)
     client_state = {key: value for key, value in client_state.items() if key not in deepspeed_states}
     # State dict keys will include a reference to the wrapper _LightningModuleWrapperBase in old checkpoints
     # created with Lightning < 2.1. Delete the `_forward_module` prefix before saving.
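Here `weights_only=False` is explicit and justified: the files being read are the user's own local training artifacts, and they interleave tensors with pickled client state that a weights-only load cannot reconstruct. A hedged usage sketch of the helper; the import path follows the documented PyTorch Lightning utility, and the paths are illustrative:

```python
from lightning.pytorch.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict

# Collapse a sharded ZeRO checkpoint directory into one fp32 checkpoint file.
convert_zero_checkpoint_to_fp32_state_dict(
    "lightning_logs/version_0/checkpoints/epoch=0-step=100.ckpt",  # checkpoint *directory*
    "fp32_model.pt",
)
```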