@@ -25,6 +25,8 @@

# Local
from .utils import (
patch_huggingface_clip_grad_norm_fsdp2,
patch_huggingface_fsdp2_load_full_state_dict,
patch_huggingface_save_and_load_for_dtensors,
patch_torch_optim_foreach_to_not_apply_to_dtensors,
prepare_scattermoe,
@@ -144,9 +146,23 @@ def get_callbacks_and_ready_for_train(
# to save DTensors properly
patch_huggingface_save_and_load_for_dtensors()

# call this to patch torch optim to not use
# foreach for dtensors
patch_torch_optim_foreach_to_not_apply_to_dtensors()
if (
not hasattr(accelerator.state.fsdp_plugin, "fsdp_version")
or accelerator.state.fsdp_plugin.fsdp_version == 1
):
# call this to patch torch optim to not use
# foreach for dtensors; this is only needed when FSDPv1 is used, since
# FSDPv2 with transformers does implicit replication to convert everything
# to dtensors before the grad-norm and optimizer.step() operations
patch_torch_optim_foreach_to_not_apply_to_dtensors()

if (
hasattr(accelerator.state.fsdp_plugin, "fsdp_version")
and accelerator.state.fsdp_plugin.fsdp_version == 2
):
# when EP (expert parallel) and FSDPv2 are used
patch_huggingface_clip_grad_norm_fsdp2(accelerator)
patch_huggingface_fsdp2_load_full_state_dict()

return callbacks

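For readers skimming the diff: both new guards key off accelerator.state.fsdp_plugin.fsdp_version, with a missing attribute treated as FSDPv1. Below is a minimal sketch of the same gating; _apply_dtensor_patches is a hypothetical helper name, the getattr default stands in for the explicit hasattr checks, and the patch functions are the ones this PR imports from .utils.

# Sketch only (not part of the PR): an equivalent way to express the two
# version-gated branches added above, assuming it lives in the same module
# so the patch helpers can be imported exactly as in the diff.
from .utils import (
    patch_huggingface_clip_grad_norm_fsdp2,
    patch_huggingface_fsdp2_load_full_state_dict,
    patch_torch_optim_foreach_to_not_apply_to_dtensors,
)

def _apply_dtensor_patches(accelerator) -> None:
    # Older accelerate releases do not expose `fsdp_version` on the FSDP
    # plugin; treat a missing attribute as FSDPv1, matching the
    # `not hasattr(...)` check in the diff.
    fsdp_version = getattr(accelerator.state.fsdp_plugin, "fsdp_version", 1)

    if fsdp_version == 1:
        # FSDPv1: keep torch optim's foreach kernels away from DTensors.
        patch_torch_optim_foreach_to_not_apply_to_dtensors()
    else:
        # FSDPv2 (used with expert parallel): fix grad-norm clipping and
        # full-state-dict loading for DTensor-sharded parameters.
        patch_huggingface_clip_grad_norm_fsdp2(accelerator)
        patch_huggingface_fsdp2_load_full_state_dict()

Collapsing the two hasattr/version checks into a single getattr with a default of 1 is purely a style choice; the branches taken are the same as in the diff.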
@@ -14,6 +14,8 @@

# Local
from .checkpoint_utils import (
patch_huggingface_clip_grad_norm_fsdp2,
patch_huggingface_fsdp2_load_full_state_dict,
patch_huggingface_save_and_load_for_dtensors,
recover_safetensors_from_dcp,
)