diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py b/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py
index 2c715280..93e6048d 100644
--- a/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py
+++ b/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py
@@ -25,6 +25,8 @@
 # Local
 from .utils import (
+    patch_huggingface_clip_grad_norm_fsdp2,
+    patch_huggingface_fsdp2_load_full_state_dict,
     patch_huggingface_save_and_load_for_dtensors,
     patch_torch_optim_foreach_to_not_apply_to_dtensors,
     prepare_scattermoe,
@@ -144,9 +146,23 @@ def get_callbacks_and_ready_for_train(
         # to save DTensors propery
         patch_huggingface_save_and_load_for_dtensors()

-        # call this to patch torch optim to not use
-        # foreach for dtensors
-        patch_torch_optim_foreach_to_not_apply_to_dtensors()
+        if (
+            not hasattr(accelerator.state.fsdp_plugin, "fsdp_version")
+            or accelerator.state.fsdp_plugin.fsdp_version == 1
+        ):
+            # call this to patch torch optim to not use foreach for
+            # dtensors, but only when FSDPv1 is used; FSDPv2 with
+            # transformers implicitly replicates everything to DTensors
+            # before the grad-norm and optimizer.step() operations
+            patch_torch_optim_foreach_to_not_apply_to_dtensors()
+
+        if (
+            hasattr(accelerator.state.fsdp_plugin, "fsdp_version")
+            and accelerator.state.fsdp_plugin.fsdp_version == 2
+        ):
+            # when EP and FSDPv2 are used
+            patch_huggingface_clip_grad_norm_fsdp2(accelerator)
+            patch_huggingface_fsdp2_load_full_state_dict()

     return callbacks

diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/__init__.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/__init__.py
index a3f2117d..eff545e8 100644
--- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/__init__.py
+++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/__init__.py
@@ -14,6 +14,8 @@
 # Local
 from .checkpoint_utils import (
+    patch_huggingface_clip_grad_norm_fsdp2,
+    patch_huggingface_fsdp2_load_full_state_dict,
     patch_huggingface_save_and_load_for_dtensors,
     recover_safetensors_from_dcp,
 )
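Note: the `fsdp_version` gate in `get_callbacks_and_ready_for_train` above treats a
missing attribute as FSDPv1, since older accelerate releases predate that field. A
minimal sketch of the dispatch logic (the stub namespace and the `pick_patches`
helper are hypothetical, for illustration only):

    from types import SimpleNamespace

    def pick_patches(fsdp_plugin):
        # mirrors the version gate in get_callbacks_and_ready_for_train
        version = getattr(fsdp_plugin, "fsdp_version", 1)
        if version == 1:
            return ["patch_torch_optim_foreach_to_not_apply_to_dtensors"]
        return [
            "patch_huggingface_clip_grad_norm_fsdp2",
            "patch_huggingface_fsdp2_load_full_state_dict",
        ]

    # accelerate versions without the attribute fall back to the FSDPv1 patch
    assert pick_patches(SimpleNamespace())[0].startswith("patch_torch_optim")
    assert len(pick_patches(SimpleNamespace(fsdp_version=2))) == 2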
diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py
index 3d173c3b..3269e1f2 100644
--- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py
+++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py
@@ -16,9 +16,11 @@
 from collections import defaultdict
 from typing import Dict, List, Union
 import json
+import math
 import os
 import re
 import shutil
+import types

 # Third Party
 from accelerate.logging import get_logger
@@ -31,6 +33,7 @@
 )
 from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
 from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
+from torch.distributed.tensor import DTensor
 from transformers import PretrainedConfig
 from transformers.utils import CONFIG_NAME, SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME
 import torch
@@ -39,6 +42,7 @@
 # Local
 from .scattermoe_constants import (
     FILE_SAFETENSOR_INDEX,
+    KEY_EXPERT_PARALLEL,
     PARAM_NAME_ROUTER_SCATTERMOE,
     PARAM_NAME_WEIGHT_SCATTERMOE,
     get_scattermoe_conv_spec_from_archs,
@@ -242,6 +246,25 @@ def patch_huggingface_save_and_load_for_dtensors():
     patch_target_module("transformers.trainer.load_fsdp_optimizer", load_fsdp_optimizer)


+# function to monkey patch the accelerator's clip_grad_norm_
+def patch_huggingface_clip_grad_norm_fsdp2(accelerator):
+    accelerator.clip_grad_norm_ = types.MethodType(clip_grad_norm_, accelerator)
+
+
+def patch_huggingface_fsdp2_load_full_state_dict():
+    # Third Party
+    # pylint: disable=import-outside-toplevel
+    from fms_acceleration.model_patcher import patch_target_module
+
+    patch_target_module(
+        "accelerate.utils.fsdp_utils.fsdp2_load_full_state_dict",
+        fsdp2_load_full_state_dict,
+    )
+    patch_target_module(
+        "accelerate.utils.fsdp_utils.fsdp2_prepare_model", fsdp2_prepare_model
+    )
+
+
 # this function implements a trick to get the resolved cache file to acccess the safetensor
 # - NOTE: does not work if _dict_from_json_file is not called, such as in the case of GGUF files.
 def get_resolved_checkpoint_location(model_name_or_path: str):
@@ -613,6 +636,50 @@ def recover_safetensors_from_dcp(
     )

+
+def clip_grad_norm_(self, parameters, max_norm, norm_type=2):
+    """grad norm patch when EP is enabled"""
+    # code inspired from
+    # https://github.com/pytorch/torchtitan/blob/72b16b13abc88ba08f3e1796e5caee09abd94554/torchtitan/distributed/utils.py#L398
+    ep_params = []
+    non_ep_params = []
+    ep_grads = []
+    non_ep_grads = []
+
+    for p in parameters:
+        if p.grad is None:
+            continue
+        if (
+            p.device_mesh.mesh_dim_names
+            and KEY_EXPERT_PARALLEL in p.device_mesh.mesh_dim_names
+        ):
+            ep_params.append(p)
+            ep_grads.append(p.grad)
+        else:
+            non_ep_params.append(p)
+            non_ep_grads.append(p.grad)
+    ep_grads_total_norm = torch.nn.utils.get_total_norm(
+        ep_grads, norm_type, False, True
+    )
+
+    if isinstance(ep_grads_total_norm, DTensor):
+        ep_grads_total_norm = ep_grads_total_norm.full_tensor()
+
+    non_ep_grads_total_norm = torch.nn.utils.get_total_norm(
+        non_ep_grads, norm_type, False, True
+    ).full_tensor()
+
+    if math.isinf(norm_type):
+        total_norm = torch.maximum(ep_grads_total_norm, non_ep_grads_total_norm)
+    else:
+        total_norm = ep_grads_total_norm**norm_type + non_ep_grads_total_norm**norm_type
+        total_norm **= 1.0 / norm_type
+
+    torch.nn.utils.clip_grads_with_norm_(ep_params, max_norm, total_norm, True)
+    torch.nn.utils.clip_grads_with_norm_(non_ep_params, max_norm, total_norm, True)
+
+    return total_norm
+
+
 # have it serve as a conversion script
 if __name__ == "__main__":
     # Standard
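Note on the norm arithmetic above: the patched clip computes the per-group norms
separately and then combines them, since for disjoint parameter groups
||g||_p = (||g_ep||_p^p + ||g_non_ep||_p^p)^(1/p), with max() replacing the sum for
the infinity norm. A minimal sketch of the combination step with plain CPU tensors,
assuming a PyTorch recent enough (>= 2.6) to provide torch.nn.utils.get_total_norm;
the gradient values here are made up:

    import torch

    ep_grads = [torch.tensor([3.0])]
    non_ep_grads = [torch.tensor([4.0])]
    ep_norm = torch.nn.utils.get_total_norm(ep_grads, 2, False, True)
    non_ep_norm = torch.nn.utils.get_total_norm(non_ep_grads, 2, False, True)
    # combine the group norms exactly as clip_grad_norm_ does for norm_type=2
    total = (ep_norm**2 + non_ep_norm**2) ** 0.5
    assert torch.isclose(total, torch.tensor(5.0))  # sqrt(3**2 + 4**2)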
@@ -651,3 +718,264 @@ def recover_safetensors_from_dcp(
     recover_safetensors_from_dcp(
         args.checkpoint_dir, args.pretrained_model_name_or_path, args.output_dir
     )
+
+
+# code taken from HF accelerate and modified
+def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dict):
+    """
+    Loads the full state dict (which may live only on rank 0) into the sharded model.
+    This is done by broadcasting the parameters from rank 0 to all other ranks.
+    This function modifies the model in-place.
+
+    Args:
+        accelerator (`Accelerator`): The accelerator instance
+        model (`torch.nn.Module`):
+            The model to load the state dict into, expected to be on the meta
+            device, otherwise a VRAM spike can occur
+        full_sd (`dict`): The full state dict to load, may live only on rank 0
+    """
+    # Third Party
+    # pylint: disable=import-outside-toplevel
+    from torch.distributed.tensor import distribute_tensor
+    import torch.distributed as dist
+
+    # Model was previously copied to meta device
+    meta_sharded_sd = model.state_dict()
+    sharded_sd = {}
+
+    def _infer_parameter_dtype(model, param_name, empty_param):
+        try:
+            old_param = model.get_parameter_or_buffer(param_name)
+        except AttributeError:
+            # needed for LoRA, where some params are not actual *parameters*
+            base_param_name, local_param_name = param_name.rsplit(".", 1)
+            submodule = model.get_submodule(base_param_name)
+            old_param = getattr(submodule, local_param_name)
+
+        is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")
+        casting_dtype = None
+        is_param_float8_e4m3fn = (
+            is_torch_e4m3fn_available and empty_param.dtype == torch.float8_e4m3fn
+        )
+
+        if empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn:
+            casting_dtype = old_param.dtype
+
+        return old_param is not None and old_param.is_contiguous(), casting_dtype
+
+    def _cast_and_contiguous(tensor, to_contiguous, dtype):
+        if dtype is not None:
+            tensor = tensor.to(dtype=dtype)
+        if to_contiguous:
+            tensor = tensor.contiguous()
+        return tensor
+
+    ignored_params = {
+        p.detach()
+        # pylint: disable=undefined-variable
+        for p in get_parameters_from_modules(
+            accelerator.state.fsdp_plugin.ignored_modules, model, accelerator.device
+        )
+    }
+
+    # Rank 0 distributes the full state dict to the other ranks
+    if accelerator.is_main_process:
+        for (param_name, full_param), sharded_param in zip(
+            full_sd.items(), meta_sharded_sd.values()
+        ):
+            # ignored params will not be on the meta device
+            # and are not handled by FSDP
+            if sharded_param.device != torch.device("meta"):
+                sharded_sd[param_name] = sharded_param
+            else:
+                device_mesh = sharded_param.device_mesh
+                full_param = full_param.detach().to(device_mesh.device_type)
+                dist.broadcast(full_param, src=0, group=dist.group.WORLD)
+                sharded_tensor = distribute_tensor(
+                    full_param, device_mesh, sharded_param.placements
+                )
+                to_contiguous, casting_dtype = _infer_parameter_dtype(
+                    model,
+                    param_name,
+                    full_param,
+                )
+                sharded_tensor = _cast_and_contiguous(
+                    sharded_tensor, to_contiguous, casting_dtype
+                )
+                sharded_sd[param_name] = sharded_tensor
+    # We need this else to have a matching `broadcast` for all ranks, else we deadlock
+    else:
+        for param_name, sharded_param in meta_sharded_sd.items():
+            # ignored params will not be on the meta device
+            # and are not handled by FSDP
+            if sharded_param.device != torch.device("meta"):
+                sharded_sd[param_name] = sharded_param
+            else:
+                device_mesh = sharded_param.device_mesh
+                full_tensor = torch.empty(
+                    sharded_param.size(),
+                    device=device_mesh.device_type,
+                    dtype=sharded_param.dtype,
+                )
+                dist.broadcast(full_tensor, src=0, group=dist.group.WORLD)
+                sharded_tensor = distribute_tensor(
+                    full_tensor, device_mesh, sharded_param.placements
+                )
+                to_contiguous, casting_dtype = _infer_parameter_dtype(
+                    model,
+                    param_name,
+                    full_tensor,
+                )
+                sharded_tensor = _cast_and_contiguous(
+                    sharded_tensor, to_contiguous, casting_dtype
+                )
+                sharded_sd[param_name] = sharded_tensor
+
+    # we set `assign=True` because our params are on the meta device
+    model.load_state_dict(sharded_sd, assign=True)
+    return model
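The else-branch above exists purely to keep the collective calls matched: every rank
must execute one `dist.broadcast` per sharded parameter, in the same order, or the
ranks deadlock. A self-contained sketch of the pattern using a single-process gloo
group (the group setup here is illustrative only; in the patch the process group
already exists):

    import os
    import torch
    import torch.distributed as dist

    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group("gloo", rank=0, world_size=1)

    full_param = torch.arange(4.0)   # rank 0 holds the real values
    buf = full_param.clone()         # non-zero ranks would allocate torch.empty_like
    dist.broadcast(buf, src=0)       # must run on every rank, matched 1:1
    assert torch.equal(buf, full_param)
    dist.destroy_process_group()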
+
+
+# code taken from HF accelerate and modified
+def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
+    """Prepares the model for FSDP2 in-place. Also returns the model, to avoid
+    misuse of the original model.
+
+    Args:
+        accelerator (`Accelerator`): The accelerator instance
+        model (`torch.nn.Module`): The model to prepare
+
+    Returns:
+        `torch.nn.Module`: Prepared model
+    """
+    # Third Party
+    # pylint: disable=import-outside-toplevel
+    from torch.distributed.fsdp import FSDPModule, MixedPrecisionPolicy, fully_shard
+
+    is_type_fsdp = isinstance(model, FSDPModule) or (
+        # pylint: disable=undefined-variable
+        is_compiled_module(model)
+        and isinstance(model._orig_mod, FSDPModule)
+    )
+    if is_type_fsdp:
+        return model
+
+    fsdp2_plugin = accelerator.state.fsdp_plugin
+
+    fsdp2_plugin.set_auto_wrap_policy(model)
+
+    original_sd = model.state_dict()
+    mesh = getattr(accelerator, "torch_device_mesh", None)
+
+    fsdp2_kwargs = {
+        "reshard_after_forward": fsdp2_plugin.reshard_after_forward,
+        "offload_policy": fsdp2_plugin.cpu_offload,
+        # `fully_shard` doesn't accept `None` for `MixedPrecisionPolicy`
+        "mp_policy": fsdp2_plugin.mixed_precision_policy or MixedPrecisionPolicy(),
+        "mesh": (
+            mesh[tuple(accelerator.parallelism_config.fsdp_dim_names)]
+            if mesh is not None
+            else None
+        ),
+        # pylint: disable=undefined-variable
+        "ignored_params": get_parameters_from_modules(
+            fsdp2_plugin.ignored_modules, model, accelerator.device
+        ),
+    }
+
+    model_has_params4bit = False
+    for _, param in model.named_parameters():
+        # temporary fix: models with bnb params cannot be moved from GPU to a
+        # meta device with FSDP2, because torch operations don't return the
+        # original class type; bypassing the move to meta still causes the
+        # VRAM spike, but at least the model loads
+        if param.__class__.__name__ == "Params4bit":
+            model_has_params4bit = True
+            break
+
+    if fsdp2_plugin.cpu_ram_efficient_loading and not model_has_params4bit:
+        # pylint: disable=undefined-variable
+        non_persistent_buffer_fqns = get_non_persistent_buffers(
+            model, recurse=True, fqns=True
+        )
+        # pylint: disable=undefined-variable
+        original_non_persistent_buffers = copy.deepcopy(
+            {k: v for k, v in model.named_buffers() if k in non_persistent_buffer_fqns}
+        )
+        # We move the model parameters that are managed by FSDPv2 to the meta
+        # device, so that sharding happens on the meta device
+        with torch.no_grad():
+            for _, module in model.named_modules():
+                for param_name, param in list(module.named_parameters(recurse=False)):
+                    if param not in fsdp2_kwargs["ignored_params"]:
+                        # Create new parameter on meta device
+                        meta_param = torch.nn.Parameter(
+                            torch.empty(param.shape, dtype=param.dtype, device="meta"),
+                            requires_grad=param.requires_grad,
+                        )
+                        setattr(module, param_name, meta_param)
+        # model = model.to(torch.device("meta"))
+        # We need to re-tie the weights, not exactly sure why, but if we don't,
+        # references to `lm_head`/`embed_tokens` stay hanging -> more VRAM usage
+        # We assume `transformers` models have a `tie_weights` method if they support it
+        if hasattr(model, "tie_weights"):
+            model.tie_weights()
+
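+    # Note: with composable FSDP2, children must be sharded before their
+    # parents, so the bottom-up traversal below wraps inner blocks (e.g.
+    # decoder layers) first; the root model itself is handled by the
+    # fallback `fully_shard` call after the loop.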
+    # pylint: disable=undefined-variable
+    auto_wrap_policy_func = fsdp2_prepare_auto_wrap_policy(fsdp2_plugin, model)
+    if auto_wrap_policy_func is not None:
+        # We skip the model itself, as that one is always wrapped
+        # pylint: disable=undefined-variable
+        for module in get_module_children_bottom_up(model)[:-1]:
+            if auto_wrap_policy_func(module) and not isinstance(module, FSDPModule):
+                fully_shard(module, **fsdp2_kwargs)
+
+    if not isinstance(model, FSDPModule):
+        fully_shard(model, **fsdp2_kwargs)
+
+    if fsdp2_plugin.cpu_ram_efficient_loading:
+        # If `cpu_ram_efficient_loading` is enabled, only rank 0 loads the weights;
+        # the other ranks have an empty model on the `meta` device, so we need to
+        # distribute the weights properly
+        fsdp2_load_full_state_dict(accelerator, model, original_sd)
+
+    if fsdp2_plugin.cpu_ram_efficient_loading and not model_has_params4bit:
+        # We re-register the buffers, as they may not be in the state_dict
+        for fqn, buffer_tensor in original_non_persistent_buffers.items():
+            buffer_tensor = buffer_tensor.to(accelerator.device)
+
+            if "." in fqn:
+                parent_fqn, local_buffer_name = fqn.rsplit(".", 1)
+                parent_module = model.get_submodule(parent_fqn)
+            else:
+                local_buffer_name = fqn
+                parent_module = model
+
+            parent_module.register_buffer(
+                local_buffer_name, buffer_tensor, persistent=False
+            )
+
+    # We need to tie the weights again, as the call to `load_full_state_dict` breaks the tie
+    # Needs to be called both here and above:
+    # removing this call makes the model have a slightly different loss;
+    # removing the call above leads to extra memory usage as explained in the comment above
+    if hasattr(model, "tie_weights"):
+        model.tie_weights()
+
+    # There is no `dtype` attribute on nn.Module;
+    # set it to None if it doesn't exist and always do the upcast
+    model_dtype = getattr(model, "dtype", None)
+    if accelerator.mixed_precision != "no" and (
+        model_dtype is None or model_dtype != torch.float32
+    ):
+        # We upcast the model according to `deepspeed`'s implementation
+        # More info can be found in `accelerator.py:prepare_model`'s FSDP1 section
+        model = model.to(torch.float32)
+        if accelerator.is_main_process:
+            # TODO(siro1): Add a warning for each parameter that was upcasted
+            # pylint: disable=undefined-variable
+            warnings.warn(
+                "FSDP upcast of low precision parameters to fp32 (since mixed_precision != 'no') "
+                "may affect the precision of model checkpoints."
+            )
+    return model
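For reference, patch_huggingface_clip_grad_norm_fsdp2 swaps the accelerator's
clip_grad_norm_ for the EP-aware version by binding a plain function as a method on
the instance. A minimal sketch of that binding trick (the stand-in class and _clip
function below are hypothetical, for illustration only):

    import types

    class _FakeAccelerator:
        pass

    def _clip(self, parameters, max_norm, norm_type=2):
        return 0.0  # placeholder body; the real patch binds clip_grad_norm_ above

    acc = _FakeAccelerator()
    # same binding the patch uses: the function becomes a bound method of `acc`
    acc.clip_grad_norm_ = types.MethodType(_clip, acc)
    assert acc.clip_grad_norm_([], 1.0) == 0.0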