src/accelerate/utils/fsdp_utils.py (9 additions, 5 deletions):
```diff
@@ -707,13 +707,17 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
     # Set it to None if it doesn't exist and do the upcast always
     model_dtype = getattr(model, "dtype", None)
     if accelerator.mixed_precision != "no" and (model_dtype is None or model_dtype != torch.float32):
-        # We upcast the model according to `deepspeed`'s implementation
+        # We upcast the trainable parameters according to `deepspeed`'s implementation
         # More info about this can be found in `accelerator.py:prepare_model`'s FSDP1 section
-        model = model.to(torch.float32)
-        if accelerator.is_main_process:
-            # TODO(siro1): Add a warning for each parameter that was upcasted
+        upcasted_params = []
+        for name, param in model.named_parameters():
+            if param.requires_grad and param.dtype != torch.float32:
+                upcasted_params.append(name)
+                param.data = param.data.to(torch.float32)
+        if accelerator.is_main_process and upcasted_params:
             warnings.warn(
-                "FSDP upcast of low precision parameters to fp32 (since mixed_precision != 'no') may affect the precision of model checkpoints."
+                "FSDP upcast of low precision parameters to fp32 (since mixed_precision != 'no') may affect the precision of model checkpoints. "
+                f"This affects {len(upcasted_params)} parameters: {upcasted_params}..."
             )
     return model
```
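The behavioral change here is that the old code upcast the entire module with `model.to(torch.float32)`, while the new loop touches only trainable parameters, so frozen low-precision weights (for example a frozen base model under a PEFT-style adapter) keep their original dtype. Below is a minimal, self-contained sketch of that selective-upcast pattern outside of Accelerate; the `TinyModel` class and the freeze/adapter setup are illustrative assumptions, not part of the PR:

```python
import warnings

import torch
from torch import nn


# Illustrative module: a frozen fp16 "base" layer plus a trainable fp16 "adapter".
class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.base = nn.Linear(8, 8).to(torch.float16)
        self.adapter = nn.Linear(8, 8).to(torch.float16)
        for p in self.base.parameters():
            p.requires_grad_(False)  # frozen, should stay fp16


model = TinyModel()

# Same selective-upcast pattern as the diff above: only trainable,
# non-fp32 parameters are moved to fp32; frozen ones are left alone.
upcasted_params = []
for name, param in model.named_parameters():
    if param.requires_grad and param.dtype != torch.float32:
        upcasted_params.append(name)
        param.data = param.data.to(torch.float32)

if upcasted_params:
    warnings.warn(
        f"Upcast {len(upcasted_params)} trainable parameters to fp32: {upcasted_params}"
    )

assert model.base.weight.dtype == torch.float16     # frozen weight untouched
assert model.adapter.weight.dtype == torch.float32  # trainable weight upcast
```

Running this warns about `adapter.weight` and `adapter.bias` only, which mirrors why the PR gates the warning on `upcasted_params` being non-empty: if every trainable parameter is already fp32, nothing is upcast and no warning fires.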
