File tree Expand file tree Collapse file tree 2 files changed +3
-3
lines changed Expand file tree Collapse file tree 2 files changed +3
-3
lines changed Original file line number Diff line number Diff line change @@ -348,7 +348,7 @@ def _test_prepare_fsdp2_shard_all() -> None:
348348
349349 module = SimpleModule ()
350350 device = torch .device ("cuda" )
351- strategy = FSDP2Strategy (modules_to_shard = "all" )
351+ strategy = FSDP2Strategy (modules_to_shard = "all" , mp_policy = torch . bfloat16 )
352352 prepare_fsdp2 (module , device , strategy )
353353
354354 for submodule in module .modules ():
Original file line number Diff line number Diff line change @@ -371,9 +371,9 @@ def prepare_fsdp2(
371371 fsdp_kwargs ["offload_policy" ] = CPUOffloadPolicy ()
372372 if (mp_policy := strategy .mp_policy ) is not None :
373373 if isinstance (mp_policy , MixedPrecisionPolicy ):
374- fsdp_kwargs ["mixed_precision " ] = mp_policy
374+ fsdp_kwargs ["mp_policy " ] = mp_policy
375375 else :
376- fsdp_kwargs ["mixed_precision " ] = MixedPrecisionPolicy (
376+ fsdp_kwargs ["mp_policy " ] = MixedPrecisionPolicy (
377377 param_dtype = mp_policy ,
378378 reduce_dtype = mp_policy ,
379379 output_dtype = mp_policy ,
You can’t perform that action at this time.
0 commit comments