
Commit 758f3b1

fix crash in optimizer.step when fsdp2 is enabled and model is bfloat16 (#3905)
Signed-off-by: Wang, Yi <yi.a.wang@intel.com>
1 parent: cdb2d1f

File tree: 1 file changed (+1 / -1 lines)


src/accelerate/utils/fsdp_utils.py

Lines changed: 1 addition & 1 deletion
@@ -719,7 +719,7 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
         for name, param in model.named_parameters():
             if param.requires_grad and param.dtype != torch.float32:
                 upcasted_params.append(name)
-                param.data = param.data.to(torch.float32)
+                param = param.to(torch.float32)
         if accelerator.is_main_process and upcasted_params:
             warnings.warn(
                 "FSDP upcast of low precision parameters to fp32 (since mixed_precision != 'no') may affect the precision of model checkpoints. "
