diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py b/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py
index 93e6048d..afca5599 100644
--- a/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py
+++ b/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py
@@ -28,6 +28,7 @@
     patch_huggingface_clip_grad_norm_fsdp2,
     patch_huggingface_fsdp2_load_full_state_dict,
     patch_huggingface_save_and_load_for_dtensors,
+    patch_prepare_sd_options,
     patch_torch_optim_foreach_to_not_apply_to_dtensors,
     prepare_scattermoe,
 )
@@ -118,6 +119,12 @@ def get_callbacks_and_ready_for_train(
             accelerator is not None
             and getattr(accelerator.state, "fsdp_plugin", None) is not None
         ):
+            if (
+                hasattr(accelerator.state.fsdp_plugin, "fsdp_version")
+                and accelerator.state.fsdp_plugin.fsdp_version == 2
+            ):
+                # when FSDPv2 is used
+                patch_prepare_sd_options()
 
             if not self._disable_distributed:
                 # - use an internal function call to get the no split
diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/__init__.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/__init__.py
index eff545e8..b48e88b6 100644
--- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/__init__.py
+++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/__init__.py
@@ -17,6 +17,7 @@
     patch_huggingface_clip_grad_norm_fsdp2,
     patch_huggingface_fsdp2_load_full_state_dict,
     patch_huggingface_save_and_load_for_dtensors,
+    patch_prepare_sd_options,
     recover_safetensors_from_dcp,
 )
 from .scattermoe_prepare import prepare_scattermoe
diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py
index 024fcc16..98b49fa5 100644
--- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py
+++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py
@@ -107,14 +107,15 @@ def save_fsdp_model(
 def save_fsdp_optimizer(
     fsdp_plugin, accelerator, optimizer, model, output_dir, optimizer_index=0
 ):
-
     if fsdp_plugin.state_dict_type != StateDictType.SHARDED_STATE_DICT:
         raise NotImplementedError(
             "Checkpointing for megablocks only enabled for sharded state dict."
         )
-
+    sd_options = _prepare_sd_options(fsdp_plugin)
     # get the state dicts for model and optimize
-    (model_state_dict, optimizer_state_dict) = get_state_dict(model, optimizer)
+    (model_state_dict, optimizer_state_dict) = get_state_dict(
+        model, optimizer, options=sd_options
+    )
 
     # filter out lora state dict
     # TODO: Once expert layers are supported for LoRA tuning
@@ -157,6 +158,28 @@ def save_fsdp_optimizer(
     logger.info(f"Optimizer state saved in {ckpt_opt}")
 
 
+def _prepare_sd_options(fsdp_plugin):
+    sd_options = None
+
+    # we use this only for FSDP2, as it requires torch >= 2.6.0 and this api requires torch >= 2.2.0
+    if fsdp_plugin.fsdp_version == 2:
+        # pylint: disable=import-outside-toplevel
+        # Third Party
+        from torch.distributed.checkpoint.state_dict import StateDictOptions
+
+        sd_options = StateDictOptions(
+            full_state_dict=fsdp_plugin.state_dict_type
+            == StateDictType.FULL_STATE_DICT,
+            cpu_offload=getattr(fsdp_plugin.state_dict_config, "offload_to_cpu", False),
+            broadcast_from_rank0=getattr(
+                fsdp_plugin.state_dict_config, "rank0_only", False
+            ),
+            flatten_optimizer_state_dict=True,
+        )
+
+    return sd_options
+
+
 # rewrite of func from accelerate.utils.fsdp_utils.py
 # - empty function, main logic in load_fsdp_optimizer (see below).
 def load_fsdp_model(
@@ -178,15 +201,16 @@ def load_fsdp_optimizer(
     optimizer_index=0,
     adapter_only=False,
 ):
-
     accelerator.wait_for_everyone()
     if fsdp_plugin.state_dict_type != StateDictType.SHARDED_STATE_DICT:
         raise NotImplementedError(
             "Checkpointing for megablocks only enabled for sharded state dict."
        )
-
+    sd_options = _prepare_sd_options(fsdp_plugin)
     # - get the state dicts
-    model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer)
+    model_state_dict, optimizer_state_dict = get_state_dict(
+        model, optimizer, options=sd_options
+    )
 
     # - load the model state dict
     ckpt_model = os.path.join(input_dir, f"{FSDP_MODEL_NAME}_{MODEL_INDEX}")
@@ -210,6 +234,7 @@ def load_fsdp_optimizer(
         optimizer,
         model_state_dict=model_state_dict,
         optim_state_dict=optimizer_state_dict,
+        options=sd_options,
     )
 
     # FIXME:
@@ -246,6 +271,16 @@ def patch_huggingface_save_and_load_for_dtensors():
     patch_target_module("transformers.trainer.load_fsdp_optimizer", load_fsdp_optimizer)
 
 
+def patch_prepare_sd_options():
+    # Third Party
+    # pylint: disable=import-outside-toplevel
+    from fms_acceleration.model_patcher import patch_target_module
+
+    patch_target_module(
+        "accelerate.utils.fsdp_utils._prepare_sd_options", _prepare_sd_options
+    )
+
+
 # function to monkey patch accelerator clip grad_norm
 def patch_huggingface_clip_grad_norm_fsdp2(accelerator):
     accelerator.clip_grad_norm_ = types.MethodType(clip_grad_norm_, accelerator)
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/utils.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/utils.py
index 300cfbf3..26f58178 100644
--- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/utils.py
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/utils.py
@@ -31,7 +31,12 @@ def calculate_settings(n):
     pass
 
 # import guard added by flim@sg.ibm.com
-from transformers.utils.import_utils import _bitsandbytes_available
+try:
+    from transformers.utils.import_utils import _bitsandbytes_available
+except ImportError:
+    from transformers.utils.import_utils import is_bitsandbytes_available
+    _bitsandbytes_available = is_bitsandbytes_available()
+
 if _bitsandbytes_available:
     import bitsandbytes as bnb
     get_ptr = bnb.functional.get_ptr