Commit b3ce371

could it be
1 parent 9389669 commit b3ce371

File tree

3 files changed: +33 -9 lines changed


src/lightning/fabric/utilities/imports.py

Lines changed: 1 addition & 0 deletions
@@ -35,6 +35,7 @@
 _TORCH_GREATER_EQUAL_2_4 = compare_version("torch", operator.ge, "2.4.0")
 _TORCH_GREATER_EQUAL_2_4_1 = compare_version("torch", operator.ge, "2.4.1")
 _TORCH_GREATER_EQUAL_2_5 = compare_version("torch", operator.ge, "2.5.0")
+_TORCH_GREATER_EQUAL_2_6 = compare_version("torch", operator.ge, "2.6.0")
 _TORCH_LESS_EQUAL_2_6 = compare_version("torch", operator.le, "2.6.0")
 _TORCHMETRICS_GREATER_EQUAL_1_0_0 = compare_version("torchmetrics", operator.ge, "1.0.0")
 _PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10)
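The new _TORCH_GREATER_EQUAL_2_6 flag follows the existing pattern in this module: a module-level boolean computed once at import time with compare_version from lightning_utilities. A minimal sketch of how such a flag is typically defined and consumed (the require_torch_2_6 helper below is hypothetical, not part of this commit):

import operator

from lightning_utilities.core.imports import compare_version

# True when the installed torch satisfies the minimum version for FSDP2.
_TORCH_GREATER_EQUAL_2_6 = compare_version("torch", operator.ge, "2.6.0")


def require_torch_2_6(feature: str) -> None:
    # Hypothetical helper: fail fast with a clear message on older torch.
    if not _TORCH_GREATER_EQUAL_2_6:
        import torch

        raise ModuleNotFoundError(f"{feature} requires torch>=2.6.0, found torch {torch.__version__}.")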

src/lightning/pytorch/strategies/fsdp2.py

Lines changed: 14 additions & 4 deletions
@@ -48,7 +48,7 @@
     _sync_ddp_if_available,
 )
 from lightning.fabric.utilities.distributed import group as _group
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_6
 from lightning.fabric.utilities.init import _has_all_dtensor_params_or_buffers, _has_meta_device_parameters_or_buffers
 from lightning.fabric.utilities.optimizer import _optimizers_to_device
 from lightning.fabric.utilities.seed import reset_seed
@@ -66,9 +66,9 @@
     from torch.distributed.device_mesh import DeviceMesh
     from torch.distributed.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy, OffloadPolicy

-try:
+if _TORCH_GREATER_EQUAL_2_6:
     from torch.distributed.checkpoint.stateful import Stateful as _TorchStateful
-except ImportError:
+else:

     class _TorchStateful:  # type: ignore[no-redef]
         pass
@@ -131,6 +131,11 @@ def __init__(
         mp_policy: Optional["MixedPrecisionPolicy"] = None,
         **kwargs: Any,
     ) -> None:
+        if not _TORCH_GREATER_EQUAL_2_6:
+            raise ModuleNotFoundError(
+                "FSDP2Strategy requires torch>=2.6.0. "
+                f"Found torch {torch.__version__}. Please upgrade torch to use FSDP2Strategy."
+            )
         super().__init__(
             accelerator=accelerator,
             parallel_devices=parallel_devices,
@@ -206,7 +211,7 @@ def setup_environment(self) -> None:
         self._process_group_backend = self._get_process_group_backend()
         assert self.cluster_environment is not None
         kwargs: dict[str, Any] = {"timeout": self._timeout}
-        if _TORCH_GREATER_EQUAL_2_3:
+        if _TORCH_GREATER_EQUAL_2_6:
             kwargs["device_id"] = self.root_device if self.root_device.type != "cpu" else None
         _init_dist_connection(self.cluster_environment, self._process_group_backend, **kwargs)

@@ -551,6 +556,11 @@ class AppState(_TorchStateful):
     """

     def __init__(self, model: Module, optimizers: list[Optimizer]) -> None:
+        if not _TORCH_GREATER_EQUAL_2_6:
+            raise ModuleNotFoundError(
+                "AppState requires torch>=2.6.0. "
+                f"Found torch {torch.__version__}. Please upgrade torch to use AppState."
+            )
         self.model = model
         self.optimizers = optimizers
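The try/except ImportError around Stateful is replaced with an explicit check on the new version flag, and both FSDP2Strategy.__init__ and AppState.__init__ now fail fast on older torch. A standalone sketch of the same gating pattern, assuming only the flag and the public torch API (the MyStateful class is illustrative, not part of this commit):

import torch

from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_6

if _TORCH_GREATER_EQUAL_2_6:
    # Optional import guarded by the version flag instead of try/except ImportError.
    from torch.distributed.checkpoint.stateful import Stateful as _TorchStateful
else:

    class _TorchStateful:  # fallback base so the module still imports on older torch
        pass


class MyStateful(_TorchStateful):  # illustrative subclass
    def __init__(self) -> None:
        if not _TORCH_GREATER_EQUAL_2_6:
            # Fail fast at construction time with an actionable message.
            raise ModuleNotFoundError(f"MyStateful requires torch>=2.6.0, found torch {torch.__version__}.")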

tests/tests_pytorch/strategies/test_fsdp2.py

Lines changed: 18 additions & 5 deletions
@@ -132,6 +132,7 @@ def _assert_save_equality(trainer, ckpt_path, cls=TestFSDP2Model):
         assert torch.equal(ddp_param, shard_param)


+@RunIf(min_torch="2.6.0")
 @pytest.mark.parametrize("strategy", ["fsdp2", "fsdp2_cpu_offload"])
 def test_invalid_on_cpu(tmp_path, cuda_count_0, strategy):
     """Test to ensure that we raise Misconfiguration for FSDP on CPU."""
@@ -141,6 +142,7 @@ def test_invalid_on_cpu(tmp_path, cuda_count_0, strategy):
         trainer.strategy.setup_environment()


+@RunIf(min_torch="2.6.0")
 def test_custom_mixed_precision():
     """Test to ensure that passing a custom mixed precision config works."""
     from torch.distributed.fsdp import MixedPrecisionPolicy
@@ -168,6 +170,7 @@ class InvalidMPPolicy:
         FSDP2Strategy(mp_policy=InvalidMPPolicy())


+@RunIf(min_torch="2.6.0")
 @pytest.mark.filterwarnings("ignore::FutureWarning")
 @RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
 def test_strategy_sync_batchnorm(tmp_path):
@@ -185,6 +188,7 @@ def test_strategy_sync_batchnorm(tmp_path):
     _run_multiple_stages(trainer, model, os.path.join(tmp_path, "last.ckpt"))


+@RunIf(min_torch="2.6.0")
 @pytest.mark.filterwarnings("ignore::FutureWarning")
 @RunIf(min_cuda_gpus=1, skip_windows=True)
 def test_modules_without_parameters(tmp_path):
@@ -217,7 +221,7 @@ def training_step(self, batch, batch_idx):


 @pytest.mark.filterwarnings("ignore::FutureWarning")
-@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
+@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="2.6.0")
 @pytest.mark.parametrize("precision", ["16-mixed", pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True))])
 def test_strategy_checkpoint(state_dict_type, precision, tmp_path):
     """Test to ensure that checkpoint is saved correctly when using a single GPU, and all stages can be run."""
@@ -237,7 +241,7 @@ def custom_auto_wrap_policy(
     return nonwrapped_numel >= 2


-@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
+@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="2.6.0")
 @pytest.mark.parametrize(
     ("precision", "expected_dtype"),
     [
@@ -279,6 +283,7 @@ def on_fit_start(self):
     trainer.fit(model)


+@RunIf(min_torch="2.6.0")
 def test_save_checkpoint_storage_options(tmp_path):
     """Test that the FSDP strategy does not accept storage options for saving checkpoints."""
     strategy = FSDP2Strategy()
@@ -304,7 +309,7 @@ def on_train_start(self):


 @pytest.mark.filterwarnings("ignore::FutureWarning")
-@RunIf(min_cuda_gpus=2, standalone=True)
+@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.6.0")
 def test_save_load_sharded_state_dict(tmp_path):
     """Test FSDP saving and loading with the sharded state dict format."""
     strategy = FSDP2Strategy()
@@ -341,7 +346,7 @@ def test_save_load_sharded_state_dict(tmp_path):
     trainer.fit(model, ckpt_path=checkpoint_path)


-@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
+@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="2.6.0")
 @pytest.mark.parametrize(
     ("precision", "expected_dtype"),
     [
@@ -391,7 +396,7 @@ def _run_setup_assertions(empty_init, expected_device):


 @pytest.mark.filterwarnings("ignore::FutureWarning")
-@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.3.0")
+@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.6.0")
 def test_save_sharded_and_consolidate_and_load(tmp_path):
     """Test the consolidation of a FSDP2-sharded checkpoint into a single file."""

@@ -433,3 +438,11 @@ def configure_optimizers(self):
         max_steps=4,
     )
     trainer.fit(model, ckpt_path=checkpoint_path_full)
+
+
+@RunIf(max_torch="2.5")
+@pytest.mark.parametrize("strategy", ["fsdp2", "fsdp2_cpu_offload"])
+def test_fsdp2_requires_torch_2_6_or_newer(tmp_path, strategy):
+    """FSDP2 strategies should error on torch < 2.6."""
+    with pytest.raises(ModuleNotFoundError, match="FSDP2Strategy requires torch>=2.6.0."):
+        Trainer(accelerator="cpu", default_root_dir=tmp_path, fast_dev_run=True, strategy=strategy)
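All touched tests are now gated on torch 2.6: existing tests gain RunIf(min_torch="2.6.0") (or min_torch="2.6.0" added to an existing RunIf), while the new negative test uses RunIf(max_torch="2.5") so it only runs where the strategy is expected to raise. A rough sketch of how a min_torch/max_torch gate of this kind can map onto pytest.mark.skipif (simplified; the real RunIf helper in the Lightning test suite supports many more conditions):

from typing import Optional

import pytest
import torch
from packaging.version import Version


def run_if(min_torch: Optional[str] = None, max_torch: Optional[str] = None):
    # Simplified stand-in for Lightning's RunIf marker: skip the test unless the
    # installed torch version falls inside the requested range.
    torch_version = Version(Version(torch.__version__).base_version)
    conditions = []
    if min_torch is not None:
        conditions.append(torch_version < Version(min_torch))
    if max_torch is not None:
        conditions.append(torch_version >= Version(max_torch))
    reason = f"torch {torch.__version__} is outside the required range [{min_torch}, {max_torch})"
    return pytest.mark.skipif(any(conditions), reason=reason)


@run_if(min_torch="2.6.0")
def test_runs_only_on_new_torch():
    assert True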
