Commit f06964d
feat: FSDP2 with MoE kernels and expert parallel (#157)
* fsdp2 patches
  Signed-off-by: Mehant Kammakomati <[email protected]>
* fsdp2 patches
  Signed-off-by: Mehant Kammakomati <[email protected]>
* fsdp2 patches
  Signed-off-by: Mehant Kammakomati <[email protected]>
* fsdp2 patches
  Signed-off-by: Mehant Kammakomati <[email protected]>

---------

Signed-off-by: Mehant Kammakomati <[email protected]>
1 parent: 4993069

File tree: 3 files changed, +349 -3 lines

plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py

Lines changed: 19 additions & 3 deletions
@@ -25,6 +25,8 @@
 
 # Local
 from .utils import (
+    patch_huggingface_clip_grad_norm_fsdp2,
+    patch_huggingface_fsdp2_load_full_state_dict,
     patch_huggingface_save_and_load_for_dtensors,
     patch_torch_optim_foreach_to_not_apply_to_dtensors,
     prepare_scattermoe,
@@ -144,9 +146,23 @@ def get_callbacks_and_ready_for_train(
         # to save DTensors properly
         patch_huggingface_save_and_load_for_dtensors()
 
-        # call this to patch torch optim to not use
-        # foreach for dtensors
-        patch_torch_optim_foreach_to_not_apply_to_dtensors()
+        if (
+            not hasattr(accelerator.state.fsdp_plugin, "fsdp_version")
+            or accelerator.state.fsdp_plugin.fsdp_version == 1
+        ):
+            # call this to patch torch optim to not use foreach
+            # for dtensors, only when FSDP v1 is used;
+            # FSDP v2 with transformers does implicit replication to convert all to DTensors
+            # before the grad-norm and optimizer.step() operations
+            patch_torch_optim_foreach_to_not_apply_to_dtensors()
+
+        if (
+            hasattr(accelerator.state.fsdp_plugin, "fsdp_version")
+            and accelerator.state.fsdp_plugin.fsdp_version == 2
+        ):
+            # when EP and FSDP v2 are used
+            patch_huggingface_clip_grad_norm_fsdp2(accelerator)
+            patch_huggingface_fsdp2_load_full_state_dict()
 
         return callbacks
 
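The gating above keys off accelerate's fsdp_plugin state: a missing fsdp_version attribute (older accelerate releases) is treated the same as version 1. Below is a minimal sketch of that dispatch, collapsing the two hasattr checks into a single getattr with a default; the wrapper name apply_dtensor_patches is hypothetical, while the patch helpers are the ones imported above.

# Sketch only: `apply_dtensor_patches` is a hypothetical wrapper, not a
# function from the plugin; the patch helpers come from .utils above.
from .utils import (
    patch_huggingface_clip_grad_norm_fsdp2,
    patch_huggingface_fsdp2_load_full_state_dict,
    patch_torch_optim_foreach_to_not_apply_to_dtensors,
)

def apply_dtensor_patches(accelerator):
    # A missing attribute means an older accelerate, i.e. FSDP v1.
    version = getattr(accelerator.state.fsdp_plugin, "fsdp_version", 1)
    if version == 1:
        # FSDP v1: params may mix DTensors and plain tensors, so keep
        # torch.optim's foreach fast path away from DTensors.
        patch_torch_optim_foreach_to_not_apply_to_dtensors()
    elif version == 2:
        # FSDP v2 with expert parallel: transformers already converts
        # everything to DTensors before grad-norm and optimizer.step(),
        # so patch grad clipping and full-state-dict loading instead.
        patch_huggingface_clip_grad_norm_fsdp2(accelerator)
        patch_huggingface_fsdp2_load_full_state_dict()

Writing the check as getattr(..., "fsdp_version", 1) makes the v1 fallback explicit without repeating the hasattr test.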

plugins/accelerated-moe/src/fms_acceleration_moe/utils/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -14,6 +14,8 @@
 
 # Local
 from .checkpoint_utils import (
+    patch_huggingface_clip_grad_norm_fsdp2,
+    patch_huggingface_fsdp2_load_full_state_dict,
     patch_huggingface_save_and_load_for_dtensors,
     recover_safetensors_from_dcp,
 )
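With this re-export in place, downstream code can import the new FSDP v2 helpers from the package's utils namespace rather than reaching into checkpoint_utils directly; a usage sketch, assuming the installed package name fms_acceleration_moe as in the paths above:

# Usage sketch: import the new patch helpers via the re-export.
from fms_acceleration_moe.utils import (
    patch_huggingface_clip_grad_norm_fsdp2,
    patch_huggingface_fsdp2_load_full_state_dict,
)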
