 import os
-from typing import Any, Dict, Optional
+from functools import partial
+from typing import Any, Callable, Dict, Optional
 from unittest import mock
 from unittest.mock import ANY, Mock

@@ -18,7 +19,14 @@

 if _TORCH_GREATER_EQUAL_1_12:
     from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel, MixedPrecision
-    from torch.distributed.fsdp.wrap import wrap
+    from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, wrap
+else:
+    size_based_auto_wrap_policy = object
+
+if _TORCH_GREATER_EQUAL_2_0:
+    from torch.distributed.fsdp.wrap import _FSDPPolicy
+else:
+    _FSDPPolicy = object


 class TestFSDPModel(BoringModel):
@@ -117,17 +125,18 @@ def _run_multiple_stages(trainer, model, model_path: Optional[str] = None):

     _assert_save_equality(trainer, model_path, cls=model.__class__)

-    # Test entry point
-    trainer.test(model)  # model is wrapped, will not call `configure_sharded_model`
+    with torch.inference_mode():
+        # Test entry point
+        trainer.test(model)  # model is wrapped, will not call `configure_sharded_model`

-    # provide model path, will create a new unwrapped model, load it, and then call `configure_sharded_model` to wrap
-    trainer.test(ckpt_path=model_path)
+        # provide model path, will create a new unwrapped model, load it, and then call `configure_sharded_model` to wrap
+        trainer.test(ckpt_path=model_path)

-    # Predict entry point
-    trainer.predict(model)  # model is wrapped, will not call `configure_sharded_model`
+        # Predict entry point
+        trainer.predict(model)  # model is wrapped, will not call `configure_sharded_model`

-    # provide model path, will create a new unwrapped model, load it, and then call `configure_sharded_model` to wrap
-    trainer.predict(ckpt_path=model_path)
+        # provide model path, will create a new unwrapped model, load it, and then call `configure_sharded_model` to wrap
+        trainer.predict(ckpt_path=model_path)


 def _assert_save_equality(trainer, ckpt_path, cls=TestFSDPModel):
@@ -200,6 +209,20 @@ def test_fsdp_strategy_checkpoint(tmpdir, precision):
     _run_multiple_stages(trainer, model, os.path.join(tmpdir, "last.ckpt"))


+class CustomWrapPolicy(_FSDPPolicy):
+    """A wrapper around :func:`size_based_auto_wrap_policy`."""
+
+    def __init__(self, min_num_params: int):
+        self._policy: Callable = partial(size_based_auto_wrap_policy, min_num_params=min_num_params)
+
+    @property
+    def policy(self):
+        return self._policy
+
+
+custom_fsdp_policy = CustomWrapPolicy(min_num_params=2)
+
+
 if _TORCH_GREATER_EQUAL_2_0:

     def custom_auto_wrap_policy(
@@ -221,19 +244,40 @@ def custom_auto_wrap_policy(

 @RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.12")
 @pytest.mark.parametrize(
-    "model, strategy",
+    "model, strategy, strategy_cfg",
     [
-        (TestFSDPModel(), "fsdp"),
-        (TestFSDPModelAutoWrapped(), FSDPStrategy),
+        pytest.param(TestFSDPModel(), "fsdp", None, id="manually_wrapped"),
+        pytest.param(
+            TestFSDPModelAutoWrapped(),
+            FSDPStrategy,
+            {"auto_wrap_policy": custom_auto_wrap_policy},
+            marks=RunIf(max_torch="2.0.0"),
+            id="autowrap_1x",
+        ),
+        pytest.param(
+            TestFSDPModelAutoWrapped(),
+            FSDPStrategy,
+            {"auto_wrap_policy": custom_auto_wrap_policy},
+            marks=RunIf(min_torch="2.0.0"),
+            id="autowrap_2x",
+        ),
+        pytest.param(
+            TestFSDPModelAutoWrapped(),
+            FSDPStrategy,
+            {"auto_wrap_policy": custom_fsdp_policy, "use_orig_params": True},
+            marks=RunIf(min_torch="2.0.0"),
+            id="autowrap_use_orig_params",
+        ),
     ],
 )
-def test_fsdp_checkpoint_multi_gpus(tmpdir, model, strategy):
+def test_fsdp_checkpoint_multi_gpus(tmpdir, model, strategy, strategy_cfg):
     """Test to ensure that checkpoint is saved correctly when using multiple GPUs, and all stages can be run."""

     ck = ModelCheckpoint(save_last=True)

+    strategy_cfg = strategy_cfg or {}
     if not isinstance(strategy, str):
-        strategy = strategy(auto_wrap_policy=custom_auto_wrap_policy)
+        strategy = strategy(**strategy_cfg)

     trainer = Trainer(
         default_root_dir=tmpdir,
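For context on the new `strategy_cfg` parametrization, here is a minimal sketch (not part of this commit) of how the size-based wrap policy defined above could be handed to `FSDPStrategy` outside the test harness. The `pytorch_lightning` import path and the `Trainer` arguments are assumptions for illustration, based on the release this test file appears to target.

# Minimal sketch; assumes pytorch_lightning-style imports, not part of the diff above.
import pytorch_lightning as pl
from pytorch_lightning.strategies import FSDPStrategy

# Reuse the policy instance from the diff: wrap submodules once they exceed 2 parameters.
strategy = FSDPStrategy(
    auto_wrap_policy=custom_fsdp_policy,  # CustomWrapPolicy(min_num_params=2) defined above
    use_orig_params=True,  # forwarded to FullyShardedDataParallel (torch >= 2.0)
)

trainer = pl.Trainer(
    accelerator="gpu",
    devices=2,  # hypothetical two-GPU run, mirroring the RunIf(min_cuda_gpus=2) marker
    strategy=strategy,
    fast_dev_run=True,  # quick smoke run for illustration only
)
# trainer.fit(TestFSDPModelAutoWrapped())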