@@ -29,6 +29,7 @@
     get_launch_command,
     path_in_accelerate_package,
     require_fp16,
+    require_fsdp2,
     require_multi_device,
     require_non_cpu,
     require_non_torch_xla,
@@ -37,7 +38,6 @@
 )
 from accelerate.utils import is_bf16_available, is_fp16_available, is_hpu_available, patch_environment, set_seed
 from accelerate.utils.constants import (
-    FSDP2_PYTORCH_VERSION,
     FSDP2_STATE_DICT_TYPE,
     FSDP_AUTO_WRAP_POLICY,
     FSDP_BACKWARD_PREFETCH,
@@ -46,7 +46,6 @@
 )
 from accelerate.utils.dataclasses import FullyShardedDataParallelPlugin
 from accelerate.utils.fsdp_utils import disable_fsdp_ram_efficient_loading, enable_fsdp_ram_efficient_loading
-from accelerate.utils.versions import is_torch_version


 set_seed(42)
@@ -63,10 +62,6 @@
 if is_bf16_available():
     dtypes.append(BF16)

-FSDP_VERSIONS = [1]
-if is_torch_version(">=", FSDP2_PYTORCH_VERSION):
-    FSDP_VERSIONS.append(2)
-

 @require_non_cpu
 @require_non_torch_xla
@@ -90,24 +85,7 @@ def setUp(self):
             2: self.fsdp2_env,
         }

-    def run(self, result=None):
-        """Override run to get the current test name and format failures to include FSDP version."""
-        test_method = getattr(self, self._testMethodName)
-        orig_test_method = test_method
-
-        def test_wrapper(*args, **kwargs):
-            for fsdp_version in FSDP_VERSIONS:
-                try:
-                    self.current_fsdp_version = fsdp_version
-                    return orig_test_method(*args, **kwargs)
-                except Exception as e:
-                    raise type(e)(f"FSDP version {fsdp_version}: {str(e)}") from e
-
-        setattr(self, self._testMethodName, test_wrapper)
-        try:
-            return super().run(result)
-        finally:
-            setattr(self, self._testMethodName, orig_test_method)
+        self.current_fsdp_version = 1

     def test_sharding_strategy(self):
         from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy
@@ -421,6 +399,15 @@ def test_cpu_ram_efficient_loading(self):
         assert os.environ.get("FSDP_CPU_RAM_EFFICIENT_LOADING") == "False"


+@require_fsdp2
+@require_non_cpu
+@require_non_torch_xla
+class FSDP2PluginIntegration(FSDPPluginIntegration):
+    def setUp(self):
+        super().setUp()
+        self.current_fsdp_version = 2
+
+
 @run_first
 # Skip this test when TorchXLA is available because accelerate.launch does not support TorchXLA FSDP.
 @require_non_torch_xla
@@ -462,24 +449,7 @@ def setUp(self):
         self.n_train = 160
         self.n_val = 160

-    def run(self, result=None):
-        """Override run to get the current test name and format failures to include FSDP version."""
-        test_method = getattr(self, self._testMethodName)
-        orig_test_method = test_method
-
-        def test_wrapper(*args, **kwargs):
-            for fsdp_version in FSDP_VERSIONS:
-                try:
-                    self.current_fsdp_version = fsdp_version
-                    return orig_test_method(*args, **kwargs)
-                except Exception as e:
-                    raise type(e)(f"FSDP version {fsdp_version}: {str(e)}") from e
-
-        setattr(self, self._testMethodName, test_wrapper)
-        try:
-            return super().run(result)
-        finally:
-            setattr(self, self._testMethodName, orig_test_method)
+        self.current_fsdp_version = 1

     @require_fp16
     def test_performance(self):
@@ -633,3 +603,15 @@ def test_peak_memory_usage(self):
         )
         with patch_environment(omp_num_threads=1):
             execute_subprocess_async(cmd_config)
+
+
+@require_fsdp2
+@run_first
+# Skip this test when TorchXLA is available because accelerate.launch does not support TorchXLA FSDP.
+@require_non_torch_xla
+@require_multi_device
+@slow
+class FSDP2IntegrationTest(FSDPIntegrationTest):
+    def setUp(self):
+        super().setUp()
+        self.current_fsdp_version = 2
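
The diff replaces the overridden run() loop with ordinary test-class inheritance: each FSDP2 class inherits every test from its FSDP1 base and only overrides setUp, so unittest collects each test twice and reports failures under the subclass name instead of relying on a rewritten exception message. Below is a minimal sketch of that pattern, using plain unittest with a hypothetical fsdp2_available() gate standing in for accelerate's require_fsdp2 decorator; the class names are illustrative, not the real ones.

import unittest


def fsdp2_available():
    # Hypothetical stand-in for the real FSDP2 availability/version check.
    return True


class PluginTests(unittest.TestCase):
    """Base class: every test reads self.current_fsdp_version set in setUp."""

    def setUp(self):
        self.current_fsdp_version = 1

    def test_version_is_configured(self):
        # Collected twice: once here (FSDP1) and once via the subclass (FSDP2),
        # so a failure is attributed to the class that configured it.
        assert self.current_fsdp_version in (1, 2)


@unittest.skipUnless(fsdp2_available(), "FSDP2 requires a recent PyTorch")
class FSDP2PluginTests(PluginTests):
    """Re-runs every inherited test with the FSDP2 configuration."""

    def setUp(self):
        super().setUp()
        self.current_fsdp_version = 2


if __name__ == "__main__":
    unittest.main()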