diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py b/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py
index 123f99cb..2c715280 100644
--- a/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py
+++ b/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py
@@ -16,8 +16,10 @@
 from typing import Dict, Tuple
 
 # Third Party
+from accelerate.logging import get_logger
 from fms_acceleration import AccelerationPlugin
 from peft import LoraConfig
+from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
 from transformers import TrainingArguments
 import torch
 
@@ -28,6 +30,8 @@
     prepare_scattermoe,
 )
 
+logger = get_logger(__name__)
+
 
 # pylint: disable=too-many-instance-attributes
 class ScatterMoEAccelerationPlugin(AccelerationPlugin):
@@ -124,6 +128,18 @@ def get_callbacks_and_ready_for_train(
             if layer.__class__.__name__ in _layers
         ]
 
+        if (
+            accelerator.state.fsdp_plugin.state_dict_type
+            != StateDictType.SHARDED_STATE_DICT
+        ):
+            accelerator.state.fsdp_plugin.state_dict_type = (
+                StateDictType.SHARDED_STATE_DICT
+            )
+            logger.warning(
+                "Overriding FSDP plugin state_dict_type to "
+                f"{StateDictType.SHARDED_STATE_DICT}, "
+                f"since the plugin does not support {StateDictType.FULL_STATE_DICT}"
+            )
         # call this to patch the HF save and load functions to be able
         # to save DTensors propery
         patch_huggingface_save_and_load_for_dtensors()
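
Not part of the patch: a minimal standalone sketch of the check-and-override added in the last hunk, assuming accelerate's `FullyShardedDataParallelPlugin` can be constructed outside a distributed run. The plugin instance below is only a stand-in for `accelerator.state.fsdp_plugin`, and `print` stands in for the accelerate logger.

```python
# Sketch only: mimics the state_dict_type override from the hunk above using a
# standalone FullyShardedDataParallelPlugin instead of accelerator.state.fsdp_plugin.
from accelerate.utils import FullyShardedDataParallelPlugin
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType

fsdp_plugin = FullyShardedDataParallelPlugin()
# Simulate a config that selects FULL_STATE_DICT checkpointing.
fsdp_plugin.state_dict_type = StateDictType.FULL_STATE_DICT

# Same check-and-override the patch performs before training starts.
if fsdp_plugin.state_dict_type != StateDictType.SHARDED_STATE_DICT:
    fsdp_plugin.state_dict_type = StateDictType.SHARDED_STATE_DICT
    print(
        "Overriding FSDP plugin state_dict_type to "
        f"{StateDictType.SHARDED_STATE_DICT}, "
        f"since the plugin does not support {StateDictType.FULL_STATE_DICT}"
    )
```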