Commit c3bb0cf

Merge pull request #151 from kmehant/nit-patch-sd
refactor: patch sharding state dict and warn
2 parents a59bf06 + 598c280 · commit c3bb0cf

File tree

1 file changed: +16 −0 lines

plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py

Lines changed: 16 additions & 0 deletions
@@ -16,8 +16,10 @@
 from typing import Dict, Tuple
 
 # Third Party
+from accelerate.logging import get_logger
 from fms_acceleration import AccelerationPlugin
 from peft import LoraConfig
+from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
 from transformers import TrainingArguments
 import torch
 
@@ -28,6 +30,8 @@
     prepare_scattermoe,
 )
 
+logger = get_logger(__name__)
+
 
 # pylint: disable=too-many-instance-attributes
 class ScatterMoEAccelerationPlugin(AccelerationPlugin):
@@ -124,6 +128,18 @@ def get_callbacks_and_ready_for_train(
             if layer.__class__.__name__ in _layers
         ]
 
+        if (
+            accelerator.state.fsdp_plugin.state_dict_type
+            != StateDictType.SHARDED_STATE_DICT
+        ):
+            accelerator.state.fsdp_plugin.state_dict_type = (
+                StateDictType.SHARDED_STATE_DICT
+            )
+            logger.warning(
+                "Overriding FSDP plugin state_dict_type to "
+                f"{StateDictType.SHARDED_STATE_DICT}, "
+                f"since the plugin does not support {StateDictType.FULL_STATE_DICT}"
+            )
         # call this to patch the HF save and load functions to be able
         # to save DTensors properly
         patch_huggingface_save_and_load_for_dtensors()
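For context, the guard added above forces Accelerate's FSDP plugin to use sharded state dicts: the ScatterMoE plugin patches the Hugging Face save/load path to handle DTensors and, per its own warning, does not support FULL_STATE_DICT. Below is a minimal, hedged sketch of the same check-and-override pattern in isolation; it assumes an Accelerator running under FSDP, and the helper name ensure_sharded_state_dict is illustrative, not part of the plugin's API.

# Illustrative sketch only; mirrors the guard introduced in this commit.
# Assumes accelerate and a torch build with FSDP support are installed.
from accelerate import Accelerator
from accelerate.logging import get_logger
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType

logger = get_logger(__name__)


def ensure_sharded_state_dict(accelerator: Accelerator) -> None:
    # Hypothetical helper: force SHARDED_STATE_DICT on the FSDP plugin,
    # warning when an existing setting (e.g. FULL_STATE_DICT) is overridden.
    fsdp_plugin = getattr(accelerator.state, "fsdp_plugin", None)
    if fsdp_plugin is None:
        # Not running under FSDP; nothing to override.
        return
    if fsdp_plugin.state_dict_type != StateDictType.SHARDED_STATE_DICT:
        fsdp_plugin.state_dict_type = StateDictType.SHARDED_STATE_DICT
        logger.warning(
            "Overriding FSDP plugin state_dict_type to "
            f"{StateDictType.SHARDED_STATE_DICT}, since "
            f"{StateDictType.FULL_STATE_DICT} is not supported here."
        )

In practice, a user could avoid triggering the warning by selecting sharded state dicts up front, for example via fsdp_state_dict_type: SHARDED_STATE_DICT in the accelerate FSDP config; this is an assumption about the surrounding setup, not something stated in the commit.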
