plugins/accelerated-moe/src/fms_acceleration_moe (1 file changed: +16 −0)

@@ -16,8 +16,10 @@
 from typing import Dict, Tuple
 
 # Third Party
+from accelerate.logging import get_logger
 from fms_acceleration import AccelerationPlugin
 from peft import LoraConfig
+from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
 from transformers import TrainingArguments
 import torch
 
@@ -28,6 +30,8 @@
     prepare_scattermoe,
 )
 
+logger = get_logger(__name__)
+
 
 # pylint: disable=too-many-instance-attributes
 class ScatterMoEAccelerationPlugin(AccelerationPlugin):
@@ -124,6 +128,18 @@ def get_callbacks_and_ready_for_train(
             if layer.__class__.__name__ in _layers
         ]
 
+        if (
+            accelerator.state.fsdp_plugin.state_dict_type
+            != StateDictType.SHARDED_STATE_DICT
+        ):
+            accelerator.state.fsdp_plugin.state_dict_type = (
+                StateDictType.SHARDED_STATE_DICT
+            )
+            logger.warning(
+                "Overriding FSDP plugin state_dict_type to "
+                f"{StateDictType.SHARDED_STATE_DICT}, since the plugin "
+                f"does not support {StateDictType.FULL_STATE_DICT}"
+            )
         # call this to patch the HF save and load functions to be able
         # to save DTensors properly
         patch_huggingface_save_and_load_for_dtensors()
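
For reference, a minimal runnable sketch of the guard this diff adds. A types.SimpleNamespace stands in for accelerator.state.fsdp_plugin (an assumption for illustration; in the real code path the object comes from accelerate), while StateDictType is the same torch enum the patch imports:

# Illustrative sketch only, not part of the patch. SimpleNamespace stands in
# for accelerator.state.fsdp_plugin; state_dict_type is the same torch enum.
from types import SimpleNamespace

from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType

fsdp_plugin = SimpleNamespace(state_dict_type=StateDictType.FULL_STATE_DICT)

# FULL_STATE_DICT gathers the whole model onto a single rank at save time;
# the plugin instead requires per-rank sharded checkpoints, so the callback
# overrides the setting and warns rather than failing later.
if fsdp_plugin.state_dict_type != StateDictType.SHARDED_STATE_DICT:
    fsdp_plugin.state_dict_type = StateDictType.SHARDED_STATE_DICT
    print(
        "Overriding FSDP plugin state_dict_type to "
        f"{StateDictType.SHARDED_STATE_DICT}, since the plugin "
        f"does not support {StateDictType.FULL_STATE_DICT}"
    )

assert fsdp_plugin.state_dict_type is StateDictType.SHARDED_STATE_DICT

Forcing SHARDED_STATE_DICT keeps each rank writing only its own shard, which lines up with the DTensor-aware save/load patching applied immediately afterwards.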