@@ -132,6 +132,7 @@ def _assert_save_equality(trainer, ckpt_path, cls=TestFSDP2Model):
            assert torch.equal(ddp_param, shard_param)


+@RunIf(min_torch="2.6.0")
@pytest.mark.parametrize("strategy", ["fsdp2", "fsdp2_cpu_offload"])
def test_invalid_on_cpu(tmp_path, cuda_count_0, strategy):
    """Test to ensure that we raise Misconfiguration for FSDP on CPU."""
@@ -141,6 +142,7 @@ def test_invalid_on_cpu(tmp_path, cuda_count_0, strategy):
        trainer.strategy.setup_environment()


+@RunIf(min_torch="2.6.0")
def test_custom_mixed_precision():
    """Test to ensure that passing a custom mixed precision config works."""
    from torch.distributed.fsdp import MixedPrecisionPolicy
@@ -168,6 +170,7 @@ class InvalidMPPolicy:
        FSDP2Strategy(mp_policy=InvalidMPPolicy())


+@RunIf(min_torch="2.6.0")
@pytest.mark.filterwarnings("ignore::FutureWarning")
@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
def test_strategy_sync_batchnorm(tmp_path):
@@ -185,6 +188,7 @@ def test_strategy_sync_batchnorm(tmp_path):
    _run_multiple_stages(trainer, model, os.path.join(tmp_path, "last.ckpt"))


+@RunIf(min_torch="2.6.0")
@pytest.mark.filterwarnings("ignore::FutureWarning")
@RunIf(min_cuda_gpus=1, skip_windows=True)
def test_modules_without_parameters(tmp_path):
@@ -217,7 +221,7 @@ def training_step(self, batch, batch_idx):


@pytest.mark.filterwarnings("ignore::FutureWarning")
-@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
+@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="2.6.0")
@pytest.mark.parametrize("precision", ["16-mixed", pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True))])
def test_strategy_checkpoint(state_dict_type, precision, tmp_path):
    """Test to ensure that checkpoint is saved correctly when using a single GPU, and all stages can be run."""
@@ -237,7 +241,7 @@ def custom_auto_wrap_policy(
    return nonwrapped_numel >= 2


-@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
+@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="2.6.0")
@pytest.mark.parametrize(
    ("precision", "expected_dtype"),
    [
@@ -279,6 +283,7 @@ def on_fit_start(self):
    trainer.fit(model)


+@RunIf(min_torch="2.6.0")
def test_save_checkpoint_storage_options(tmp_path):
    """Test that the FSDP strategy does not accept storage options for saving checkpoints."""
    strategy = FSDP2Strategy()
@@ -304,7 +309,7 @@ def on_train_start(self):


@pytest.mark.filterwarnings("ignore::FutureWarning")
-@RunIf(min_cuda_gpus=2, standalone=True)
+@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.6.0")
def test_save_load_sharded_state_dict(tmp_path):
    """Test FSDP saving and loading with the sharded state dict format."""
    strategy = FSDP2Strategy()
@@ -341,7 +346,7 @@ def test_save_load_sharded_state_dict(tmp_path):
    trainer.fit(model, ckpt_path=checkpoint_path)


-@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
+@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="2.6.0")
@pytest.mark.parametrize(
    ("precision", "expected_dtype"),
    [
@@ -391,7 +396,7 @@ def _run_setup_assertions(empty_init, expected_device):


@pytest.mark.filterwarnings("ignore::FutureWarning")
-@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.3.0")
+@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.6.0")
def test_save_sharded_and_consolidate_and_load(tmp_path):
    """Test the consolidation of a FSDP2-sharded checkpoint into a single file."""
@@ -433,3 +438,11 @@ def configure_optimizers(self):
        max_steps=4,
    )
    trainer.fit(model, ckpt_path=checkpoint_path_full)
+
+
+@RunIf(max_torch="2.5")
+@pytest.mark.parametrize("strategy", ["fsdp2", "fsdp2_cpu_offload"])
+def test_fsdp2_requires_torch_2_6_or_newer(tmp_path, strategy):
+    """FSDP2 strategies should error on torch < 2.6."""
+    with pytest.raises(ValueError, match="FSDP2Strategy requires torch>=2.6.0."):
+        Trainer(accelerator="cpu", default_root_dir=tmp_path, fast_dev_run=True, strategy=strategy)