From 51f1e3ba9bb81d0371dbbbade9d4f98a431a7474 Mon Sep 17 00:00:00 2001 From: Mehant Kammakomati Date: Wed, 15 Oct 2025 02:04:19 +0530 Subject: [PATCH 1/4] fsdp2 patches Signed-off-by: Mehant Kammakomati --- .../framework_plugin_scattermoe.py | 16 +- .../fms_acceleration_moe/utils/__init__.py | 2 + .../utils/checkpoint_utils.py | 278 ++++++++++++++++++ 3 files changed, 293 insertions(+), 3 deletions(-) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py b/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py index 2c715280..54898655 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py @@ -28,6 +28,8 @@ patch_huggingface_save_and_load_for_dtensors, patch_torch_optim_foreach_to_not_apply_to_dtensors, prepare_scattermoe, + patch_huggingface_clip_grad_norm_fsdp2, + patch_huggingface_fsdp2_load_full_state_dict ) logger = get_logger(__name__) @@ -143,10 +145,18 @@ def get_callbacks_and_ready_for_train( # call this to patch the HF save and load functions to be able # to save DTensors propery patch_huggingface_save_and_load_for_dtensors() - + + if not hasattr(accelerator.state.fsdp_plugin, "fsdp_version") or accelerator.state.fsdp_plugin.fsdp_version == 1: # call this to patch torch optim to not use - # foreach for dtensors - patch_torch_optim_foreach_to_not_apply_to_dtensors() + # foreach for dtensors only when fsdpv1 is used + # fsdpv2 with transformers does implicit replication to convert all to dtensors + # before grad norm and optimizer.step() operations + patch_torch_optim_foreach_to_not_apply_to_dtensors() + + if hasattr(accelerator.state.fsdp_plugin, "fsdp_version") and accelerator.state.fsdp_plugin.fsdp_version == 2: + # when EP and FSDPv2 is used + patch_huggingface_clip_grad_norm_fsdp2(accelerator) + patch_huggingface_fsdp2_load_full_state_dict() return callbacks diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/__init__.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/__init__.py index a3f2117d..6a66ecb4 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/__init__.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/__init__.py @@ -16,6 +16,8 @@ from .checkpoint_utils import ( patch_huggingface_save_and_load_for_dtensors, recover_safetensors_from_dcp, + patch_huggingface_clip_grad_norm_fsdp2, + patch_huggingface_fsdp2_load_full_state_dict, ) from .scattermoe_prepare import prepare_scattermoe diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py index 3d173c3b..d26647c0 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py @@ -19,6 +19,8 @@ import os import re import shutil +import types +import math # Third Party from accelerate.logging import get_logger @@ -35,12 +37,14 @@ from transformers.utils import CONFIG_NAME, SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME import torch import torch.distributed.checkpoint as dcp +from torch.distributed.tensor import DTensor # Local from .scattermoe_constants import ( FILE_SAFETENSOR_INDEX, PARAM_NAME_ROUTER_SCATTERMOE, PARAM_NAME_WEIGHT_SCATTERMOE, + KEY_EXPERT_PARALLEL, get_scattermoe_conv_spec_from_archs, ) from .scattermoe_state_dict import 
get_checkpoint_meta_from_sharded_safetensor @@ -241,6 +245,15 @@ def patch_huggingface_save_and_load_for_dtensors(): patch_target_module("transformers.trainer.load_fsdp_model", load_fsdp_model) patch_target_module("transformers.trainer.load_fsdp_optimizer", load_fsdp_optimizer) +# function to monkey patch accelerator clip grad_norm +def patch_huggingface_clip_grad_norm_fsdp2(accelerator): + accelerator.clip_grad_norm_ = types.MethodType(clip_grad_norm_, accelerator) + +def patch_huggingface_fsdp2_load_full_state_dict(): + from fms_acceleration.model_patcher import patch_target_module + patch_target_module("accelerate.utils.fsdp_utils.fsdp2_load_full_state_dict", fsdp2_load_full_state_dict) + patch_target_module("accelerate.utils.fsdp_utils.fsdp2_prepare_model", fsdp2_prepare_model) + # this function implements a trick to get the resolved cache file to acccess the safetensor # - NOTE: does not work if _dict_from_json_file is not called, such as in the case of GGUF files. @@ -613,6 +626,48 @@ def recover_safetensors_from_dcp( ) +def clip_grad_norm_(self, parameters, max_norm, norm_type=2): + """grad norm patch when EP is enabled""" + # code inspired from + # https://github.com/pytorch/torchtitan/blob/72b16b13abc88ba08f3e1796e5caee09abd94554/torchtitan/distributed/utils.py#L398 + ep_params = [] + non_ep_params = [] + ep_grads = [] + non_ep_grads = [] + + for p in parameters: + if p.grad is None: + continue + if p.device_mesh.mesh_dim_names and KEY_EXPERT_PARALLEL in p.device_mesh.mesh_dim_names: + ep_params.append(p) + ep_grads.append(p.grad) + else: + non_ep_params.append(p) + non_ep_grads.append(p.grad) + ep_grads_total_norm = torch.nn.utils.get_total_norm( + ep_grads, norm_type, False, True + ) + + if isinstance(ep_grads_total_norm, DTensor): + ep_grads_total_norm = ep_grads_total_norm.full_tensor() + + non_ep_grads_total_norm = torch.nn.utils.get_total_norm( + non_ep_grads, norm_type, False, True + ).full_tensor() + + if math.isinf(norm_type): + total_norm = torch.maximum(ep_grads_total_norm, non_ep_grads_total_norm) + else: + total_norm = ( + ep_grads_total_norm**norm_type + non_ep_grads_total_norm**norm_type + ) + total_norm **= 1.0 / norm_type + + torch.nn.utils.clip_grads_with_norm_(ep_params, max_norm, total_norm, True) + torch.nn.utils.clip_grads_with_norm_(non_ep_params, max_norm, total_norm, True) + + return total_norm + # have it serve as a conversion script if __name__ == "__main__": # Standard @@ -651,3 +706,226 @@ def recover_safetensors_from_dcp( recover_safetensors_from_dcp( args.checkpoint_dir, args.pretrained_model_name_or_path, args.output_dir ) + + +# code taken from HF accelerate and modified +def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dict): + """ + Loads the full state dict (could be only on rank 0) into the sharded model. This is done by broadcasting the + parameters from rank 0 to all other ranks. This function modifies the model in-place. 
+ + Args: + accelerator (`Accelerator`): The accelerator instance + model (`torch.nn.Module`): + The model to load the state dict into, expected to be on meta device or a VRAM spike can occur + full_sd (`dict`): The full state dict to load, can only be on rank 0 + """ + import torch.distributed as dist + from torch.distributed.tensor import distribute_tensor + + # Model was previously copied to meta device + meta_sharded_sd = model.state_dict() + sharded_sd = {} + + # Rank 0 distributes the full state dict to other ranks + def _infer_parameter_dtype(model, param_name, empty_param): + try: + old_param = model.get_parameter_or_buffer(param_name) + except AttributeError: + # Need this for LORA, as there some params are not *parameters* of sorts + base_param_name, local_param_name = param_name.rsplit(".", 1) + submodule = model.get_submodule(base_param_name) + old_param = getattr(submodule, local_param_name) + + is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn") + casting_dtype = None + is_param_float8_e4m3fn = is_torch_e4m3fn_available and empty_param.dtype == torch.float8_e4m3fn + + if empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn: + casting_dtype = old_param.dtype + + return old_param is not None and old_param.is_contiguous(), casting_dtype + + def _cast_and_contiguous(tensor, to_contiguous, dtype): + if dtype is not None: + tensor = tensor.to(dtype=dtype) + if to_contiguous: + tensor = tensor.contiguous() + return tensor + # ignored_params = get_parameters_from_modules(accelerator.state.fsdp_plugin.ignored_modules, model, accelerator.device) + ignored_params = {p.detach() for p in get_parameters_from_modules(accelerator.state.fsdp_plugin.ignored_modules, model, accelerator.device)} + if accelerator.is_main_process: + for (param_name, full_param), sharded_param in zip(full_sd.items(), meta_sharded_sd.values()): + # ignored params will not be on meta device + # and not handled by FSDP + if sharded_param.device != torch.device("meta"): + sharded_sd[param_name] = sharded_param + else: + device_mesh = sharded_param.device_mesh + full_param = full_param.detach().to(device_mesh.device_type) + dist.broadcast(full_param, src=0, group=dist.group.WORLD) + # if device_mesh.ndim > 1: + # for mesh_dim_name in device_mesh.mesh_dim_names: + # dist.broadcast(full_param, src=0, group=device_mesh.get_group(mesh_dim=mesh_dim_name)) + # else: + # dist.broadcast(full_param, src=0, group=device_mesh.get_group()) + sharded_tensor = distribute_tensor(full_param, device_mesh, sharded_param.placements) + to_contiguous, casting_dtype = _infer_parameter_dtype( + model, + param_name, + full_param, + ) + sharded_tensor = _cast_and_contiguous(sharded_tensor, to_contiguous, casting_dtype) + sharded_sd[param_name] = sharded_tensor + # We need this else to have a matching `broadcast` for all of the ranks, else we deadlock + else: + for param_name, sharded_param in meta_sharded_sd.items(): + # ignored params will not be on meta device + # and not handled by FSDP + if sharded_param.device != torch.device("meta"): + sharded_sd[param_name] = sharded_param + else: + device_mesh = sharded_param.device_mesh + full_tensor = torch.empty(sharded_param.size(), device=device_mesh.device_type, dtype=sharded_param.dtype) + dist.broadcast(full_tensor, src=0, group=dist.group.WORLD) + # if device_mesh.ndim > 1: + # for mesh_dim_name in device_mesh.mesh_dim_names: + # dist.broadcast(full_tensor, src=0, group=device_mesh.get_group(mesh_dim=mesh_dim_name)) + # else: + # dist.broadcast(full_tensor, src=0, 
group=device_mesh.get_group()) + sharded_tensor = distribute_tensor(full_tensor, device_mesh, sharded_param.placements) + to_contiguous, casting_dtype = _infer_parameter_dtype( + model, + param_name, + full_tensor, + ) + sharded_tensor = _cast_and_contiguous(sharded_tensor, to_contiguous, casting_dtype) + sharded_sd[param_name] = sharded_tensor + + # we set `assign=True` because our params are on meta device + model.load_state_dict(sharded_sd, assign=True) + return model + +# code taken from HF accelerate and modified +def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module: + """Prepares the model for FSDP2 in-place. Also returns the model to avoid misuse of the original model. + + Args: + accelerator (`Accelerator`): The accelerator instance + model (`torch.nn.Module`): The model to prepare + + Returns: + `torch.nn.Module`: Prepared model + """ + from torch.distributed.fsdp import FSDPModule, MixedPrecisionPolicy, fully_shard + + is_type_fsdp = isinstance(model, FSDPModule) or ( + is_compiled_module(model) and isinstance(model._orig_mod, FSDPModule) + ) + if is_type_fsdp: + return model + + fsdp2_plugin = accelerator.state.fsdp_plugin + + fsdp2_plugin.set_auto_wrap_policy(model) + + original_sd = model.state_dict() + mesh = getattr(accelerator, "torch_device_mesh", None) + + fsdp2_kwargs = { + "reshard_after_forward": fsdp2_plugin.reshard_after_forward, + "offload_policy": fsdp2_plugin.cpu_offload, + # `fully_shard` doesn't accept `None` in case of `MixedPrecisionPolicy` + "mp_policy": fsdp2_plugin.mixed_precision_policy or MixedPrecisionPolicy(), + "mesh": mesh[tuple(accelerator.parallelism_config.fsdp_dim_names)] if mesh is not None else None, + "ignored_params": get_parameters_from_modules(fsdp2_plugin.ignored_modules, model, accelerator.device), + } + + model_has_params4bit = False + for name, param in model.named_parameters(): + # this is a temporary fix whereby loading models with bnb params cannot be moved from + # GPU to a meta device due with FSDP2 because torch operations don't return the original class type + # bypassing the move to meta will still cause the VRAM spike, but at least it still will load + if param.__class__.__name__ == "Params4bit": + model_has_params4bit = True + break + + if fsdp2_plugin.cpu_ram_efficient_loading and not model_has_params4bit: + # Context: `fully_shard` moves the model to GPU if it was on CPU, however it can also be on `meta` and then it stays there even after `fully_shard` + # For this reason, we need to move the model to `meta` device, as then sharding happens on `meta` device + # If we kept the model on CPU (`cpu_ram_efficient_loading` has model be on CPU on all ranks, though non-main ranks only have `torch.empty`), `fully_shard` would move it to GPU + # Afterwards, when we call `fsdp2_load_full_state_dict`, us creating the state_dict would result into briefly having two copies of model state_dict on the GPU -> VRAM spike + + # We need to keep the original non-persistent buffers, as those MAY not be in the state_dict, resulting in them staying on meta device + # Also, these buffers aren't getting sharded by default + # We get the FQNs of all non-persistent buffers, to re-register them after + non_persistent_buffer_fqns = get_non_persistent_buffers(model, recurse=True, fqns=True) + original_non_persistent_buffers = copy.deepcopy( + {k: v for k, v in model.named_buffers() if k in non_persistent_buffer_fqns} + ) + # We move the model parameters to meta device that are managed by FSDPv2, + # as then sharding happens on 
meta device + with torch.no_grad(): + for _, module in model.named_modules(): + for param_name, param in list(module.named_parameters(recurse=False)): + if param not in fsdp2_kwargs["ignored_params"]: + # Create new parameter on meta device + meta_param = torch.nn.Parameter( + torch.empty(param.shape, dtype=param.dtype, device="meta"), requires_grad=param.requires_grad + ) + setattr(module, param_name, meta_param) + # model = model.to(torch.device("meta")) + # We need to re-tie the weights, not exactly sure why, but if we don't do this, reference to `lm_head/embed_tokens` stay hanging -> more VRAM usage + # We assume `transformers` models have a `tie_weights` method if they support it + if hasattr(model, "tie_weights"): + model.tie_weights() + + auto_wrap_policy_func = fsdp2_prepare_auto_wrap_policy(fsdp2_plugin, model) + if auto_wrap_policy_func is not None: + # We skip the model itself, as that one is always wrapped + for module in get_module_children_bottom_up(model)[:-1]: + if auto_wrap_policy_func(module) and not isinstance(module, FSDPModule): + fully_shard(module, **fsdp2_kwargs) + + if not isinstance(model, FSDPModule): + fully_shard(model, **fsdp2_kwargs) + + if fsdp2_plugin.cpu_ram_efficient_loading: + # If `cpu_ram_efficient_loading` is enabled, only rank 0 loads the weights + # Other ranks have an empty model on `meta` device, so we need to distribute the weights properly + fsdp2_load_full_state_dict(accelerator, model, original_sd) + + if fsdp2_plugin.cpu_ram_efficient_loading and not model_has_params4bit: + # We re-register the buffers, as they may not be in the state_dict + for fqn, buffer_tensor in original_non_persistent_buffers.items(): + buffer_tensor = buffer_tensor.to(accelerator.device) + + if "." in fqn: + parent_fqn, local_buffer_name = fqn.rsplit(".", 1) + parent_module = model.get_submodule(parent_fqn) + else: + local_buffer_name = fqn + parent_module = model + + parent_module.register_buffer(local_buffer_name, buffer_tensor, persistent=False) + + # We need to tie the weights again, as call to `load_full_state_dict` breaks the tie + # Needs to be called both here and above + # removing this call makes the have slightly different loss + # removing the call above leads to extra memory usage as explained in the comment above + if hasattr(model, "tie_weights"): + model.tie_weights() + + # There is no `dtype` attribution for nn.Module + # Set it to None if it doesn't exist and do the upcast always + model_dtype = getattr(model, "dtype", None) + if accelerator.mixed_precision != "no" and (model_dtype is None or model_dtype != torch.float32): + # We upcast the model according to `deepspeed`'s implementation + # More info about this can be found in `accelerator.py:prepare_model`s FSDP1 section + model = model.to(torch.float32) + if accelerator.is_main_process: + # TODO(siro1): Add a warning for each parameter that was upcasted + warnings.warn( + "FSDP upcast of low precision parameters to fp32 (since mixed_precision != 'no') may affect the precision of model checkpoints." 
+ ) + return model \ No newline at end of file From d1e5ff45df4a3038b3e062012e47c00efa3e5567 Mon Sep 17 00:00:00 2001 From: Mehant Kammakomati Date: Wed, 15 Oct 2025 02:06:40 +0530 Subject: [PATCH 2/4] fsdp2 patches Signed-off-by: Mehant Kammakomati --- .../framework_plugin_scattermoe.py | 24 ++-- .../fms_acceleration_moe/utils/__init__.py | 4 +- .../utils/checkpoint_utils.py | 107 +++++++++++++----- 3 files changed, 95 insertions(+), 40 deletions(-) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py b/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py index 54898655..93e6048d 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py @@ -25,11 +25,11 @@ # Local from .utils import ( + patch_huggingface_clip_grad_norm_fsdp2, + patch_huggingface_fsdp2_load_full_state_dict, patch_huggingface_save_and_load_for_dtensors, patch_torch_optim_foreach_to_not_apply_to_dtensors, prepare_scattermoe, - patch_huggingface_clip_grad_norm_fsdp2, - patch_huggingface_fsdp2_load_full_state_dict ) logger = get_logger(__name__) @@ -145,15 +145,21 @@ def get_callbacks_and_ready_for_train( # call this to patch the HF save and load functions to be able # to save DTensors propery patch_huggingface_save_and_load_for_dtensors() - - if not hasattr(accelerator.state.fsdp_plugin, "fsdp_version") or accelerator.state.fsdp_plugin.fsdp_version == 1: - # call this to patch torch optim to not use - # foreach for dtensors only when fsdpv1 is used - # fsdpv2 with transformers does implicit replication to convert all to dtensors - # before grad norm and optimizer.step() operations + + if ( + not hasattr(accelerator.state.fsdp_plugin, "fsdp_version") + or accelerator.state.fsdp_plugin.fsdp_version == 1 + ): + # call this to patch torch optim to not use + # foreach for dtensors only when fsdpv1 is used + # fsdpv2 with transformers does implicit replication to convert all to dtensors + # before grad norm and optimizer.step() operations patch_torch_optim_foreach_to_not_apply_to_dtensors() - if hasattr(accelerator.state.fsdp_plugin, "fsdp_version") and accelerator.state.fsdp_plugin.fsdp_version == 2: + if ( + hasattr(accelerator.state.fsdp_plugin, "fsdp_version") + and accelerator.state.fsdp_plugin.fsdp_version == 2 + ): # when EP and FSDPv2 is used patch_huggingface_clip_grad_norm_fsdp2(accelerator) patch_huggingface_fsdp2_load_full_state_dict() diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/__init__.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/__init__.py index 6a66ecb4..eff545e8 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/__init__.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/__init__.py @@ -14,10 +14,10 @@ # Local from .checkpoint_utils import ( - patch_huggingface_save_and_load_for_dtensors, - recover_safetensors_from_dcp, patch_huggingface_clip_grad_norm_fsdp2, patch_huggingface_fsdp2_load_full_state_dict, + patch_huggingface_save_and_load_for_dtensors, + recover_safetensors_from_dcp, ) from .scattermoe_prepare import prepare_scattermoe diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py index d26647c0..7ea7d393 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py +++ 
b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py @@ -16,11 +16,11 @@ from collections import defaultdict from typing import Dict, List, Union import json +import math import os import re import shutil import types -import math # Third Party from accelerate.logging import get_logger @@ -33,18 +33,18 @@ ) from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType +from torch.distributed.tensor import DTensor from transformers import PretrainedConfig from transformers.utils import CONFIG_NAME, SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME import torch import torch.distributed.checkpoint as dcp -from torch.distributed.tensor import DTensor # Local from .scattermoe_constants import ( FILE_SAFETENSOR_INDEX, + KEY_EXPERT_PARALLEL, PARAM_NAME_ROUTER_SCATTERMOE, PARAM_NAME_WEIGHT_SCATTERMOE, - KEY_EXPERT_PARALLEL, get_scattermoe_conv_spec_from_archs, ) from .scattermoe_state_dict import get_checkpoint_meta_from_sharded_safetensor @@ -245,14 +245,23 @@ def patch_huggingface_save_and_load_for_dtensors(): patch_target_module("transformers.trainer.load_fsdp_model", load_fsdp_model) patch_target_module("transformers.trainer.load_fsdp_optimizer", load_fsdp_optimizer) + # function to monkey patch accelerator clip grad_norm def patch_huggingface_clip_grad_norm_fsdp2(accelerator): accelerator.clip_grad_norm_ = types.MethodType(clip_grad_norm_, accelerator) - + + def patch_huggingface_fsdp2_load_full_state_dict(): + # Third Party from fms_acceleration.model_patcher import patch_target_module - patch_target_module("accelerate.utils.fsdp_utils.fsdp2_load_full_state_dict", fsdp2_load_full_state_dict) - patch_target_module("accelerate.utils.fsdp_utils.fsdp2_prepare_model", fsdp2_prepare_model) + + patch_target_module( + "accelerate.utils.fsdp_utils.fsdp2_load_full_state_dict", + fsdp2_load_full_state_dict, + ) + patch_target_module( + "accelerate.utils.fsdp_utils.fsdp2_prepare_model", fsdp2_prepare_model + ) # this function implements a trick to get the resolved cache file to acccess the safetensor @@ -638,7 +647,10 @@ def clip_grad_norm_(self, parameters, max_norm, norm_type=2): for p in parameters: if p.grad is None: continue - if p.device_mesh.mesh_dim_names and KEY_EXPERT_PARALLEL in p.device_mesh.mesh_dim_names: + if ( + p.device_mesh.mesh_dim_names + and KEY_EXPERT_PARALLEL in p.device_mesh.mesh_dim_names + ): ep_params.append(p) ep_grads.append(p.grad) else: @@ -658,9 +670,7 @@ def clip_grad_norm_(self, parameters, max_norm, norm_type=2): if math.isinf(norm_type): total_norm = torch.maximum(ep_grads_total_norm, non_ep_grads_total_norm) else: - total_norm = ( - ep_grads_total_norm**norm_type + non_ep_grads_total_norm**norm_type - ) + total_norm = ep_grads_total_norm**norm_type + non_ep_grads_total_norm**norm_type total_norm **= 1.0 / norm_type torch.nn.utils.clip_grads_with_norm_(ep_params, max_norm, total_norm, True) @@ -668,6 +678,7 @@ def clip_grad_norm_(self, parameters, max_norm, norm_type=2): return total_norm + # have it serve as a conversion script if __name__ == "__main__": # Standard @@ -720,8 +731,9 @@ def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dic The model to load the state dict into, expected to be on meta device or a VRAM spike can occur full_sd (`dict`): The full state dict to load, can only be on rank 0 """ - import torch.distributed as dist + # Third Party from torch.distributed.tensor import distribute_tensor + import 
torch.distributed as dist # Model was previously copied to meta device meta_sharded_sd = model.state_dict() @@ -739,7 +751,9 @@ def _infer_parameter_dtype(model, param_name, empty_param): is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn") casting_dtype = None - is_param_float8_e4m3fn = is_torch_e4m3fn_available and empty_param.dtype == torch.float8_e4m3fn + is_param_float8_e4m3fn = ( + is_torch_e4m3fn_available and empty_param.dtype == torch.float8_e4m3fn + ) if empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn: casting_dtype = old_param.dtype @@ -752,11 +766,19 @@ def _cast_and_contiguous(tensor, to_contiguous, dtype): if to_contiguous: tensor = tensor.contiguous() return tensor + # ignored_params = get_parameters_from_modules(accelerator.state.fsdp_plugin.ignored_modules, model, accelerator.device) - ignored_params = {p.detach() for p in get_parameters_from_modules(accelerator.state.fsdp_plugin.ignored_modules, model, accelerator.device)} + ignored_params = { + p.detach() + for p in get_parameters_from_modules( + accelerator.state.fsdp_plugin.ignored_modules, model, accelerator.device + ) + } if accelerator.is_main_process: - for (param_name, full_param), sharded_param in zip(full_sd.items(), meta_sharded_sd.values()): - # ignored params will not be on meta device + for (param_name, full_param), sharded_param in zip( + full_sd.items(), meta_sharded_sd.values() + ): + # ignored params will not be on meta device # and not handled by FSDP if sharded_param.device != torch.device("meta"): sharded_sd[param_name] = sharded_param @@ -769,43 +791,56 @@ def _cast_and_contiguous(tensor, to_contiguous, dtype): # dist.broadcast(full_param, src=0, group=device_mesh.get_group(mesh_dim=mesh_dim_name)) # else: # dist.broadcast(full_param, src=0, group=device_mesh.get_group()) - sharded_tensor = distribute_tensor(full_param, device_mesh, sharded_param.placements) + sharded_tensor = distribute_tensor( + full_param, device_mesh, sharded_param.placements + ) to_contiguous, casting_dtype = _infer_parameter_dtype( model, param_name, full_param, ) - sharded_tensor = _cast_and_contiguous(sharded_tensor, to_contiguous, casting_dtype) + sharded_tensor = _cast_and_contiguous( + sharded_tensor, to_contiguous, casting_dtype + ) sharded_sd[param_name] = sharded_tensor # We need this else to have a matching `broadcast` for all of the ranks, else we deadlock else: for param_name, sharded_param in meta_sharded_sd.items(): - # ignored params will not be on meta device + # ignored params will not be on meta device # and not handled by FSDP if sharded_param.device != torch.device("meta"): sharded_sd[param_name] = sharded_param else: device_mesh = sharded_param.device_mesh - full_tensor = torch.empty(sharded_param.size(), device=device_mesh.device_type, dtype=sharded_param.dtype) + full_tensor = torch.empty( + sharded_param.size(), + device=device_mesh.device_type, + dtype=sharded_param.dtype, + ) dist.broadcast(full_tensor, src=0, group=dist.group.WORLD) # if device_mesh.ndim > 1: # for mesh_dim_name in device_mesh.mesh_dim_names: # dist.broadcast(full_tensor, src=0, group=device_mesh.get_group(mesh_dim=mesh_dim_name)) # else: # dist.broadcast(full_tensor, src=0, group=device_mesh.get_group()) - sharded_tensor = distribute_tensor(full_tensor, device_mesh, sharded_param.placements) + sharded_tensor = distribute_tensor( + full_tensor, device_mesh, sharded_param.placements + ) to_contiguous, casting_dtype = _infer_parameter_dtype( model, param_name, full_tensor, ) - sharded_tensor = 
_cast_and_contiguous(sharded_tensor, to_contiguous, casting_dtype) + sharded_tensor = _cast_and_contiguous( + sharded_tensor, to_contiguous, casting_dtype + ) sharded_sd[param_name] = sharded_tensor # we set `assign=True` because our params are on meta device model.load_state_dict(sharded_sd, assign=True) return model + # code taken from HF accelerate and modified def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module: """Prepares the model for FSDP2 in-place. Also returns the model to avoid misuse of the original model. @@ -817,6 +852,7 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module: Returns: `torch.nn.Module`: Prepared model """ + # Third Party from torch.distributed.fsdp import FSDPModule, MixedPrecisionPolicy, fully_shard is_type_fsdp = isinstance(model, FSDPModule) or ( @@ -837,8 +873,14 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module: "offload_policy": fsdp2_plugin.cpu_offload, # `fully_shard` doesn't accept `None` in case of `MixedPrecisionPolicy` "mp_policy": fsdp2_plugin.mixed_precision_policy or MixedPrecisionPolicy(), - "mesh": mesh[tuple(accelerator.parallelism_config.fsdp_dim_names)] if mesh is not None else None, - "ignored_params": get_parameters_from_modules(fsdp2_plugin.ignored_modules, model, accelerator.device), + "mesh": ( + mesh[tuple(accelerator.parallelism_config.fsdp_dim_names)] + if mesh is not None + else None + ), + "ignored_params": get_parameters_from_modules( + fsdp2_plugin.ignored_modules, model, accelerator.device + ), } model_has_params4bit = False @@ -859,11 +901,13 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module: # We need to keep the original non-persistent buffers, as those MAY not be in the state_dict, resulting in them staying on meta device # Also, these buffers aren't getting sharded by default # We get the FQNs of all non-persistent buffers, to re-register them after - non_persistent_buffer_fqns = get_non_persistent_buffers(model, recurse=True, fqns=True) + non_persistent_buffer_fqns = get_non_persistent_buffers( + model, recurse=True, fqns=True + ) original_non_persistent_buffers = copy.deepcopy( {k: v for k, v in model.named_buffers() if k in non_persistent_buffer_fqns} ) - # We move the model parameters to meta device that are managed by FSDPv2, + # We move the model parameters to meta device that are managed by FSDPv2, # as then sharding happens on meta device with torch.no_grad(): for _, module in model.named_modules(): @@ -871,7 +915,8 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module: if param not in fsdp2_kwargs["ignored_params"]: # Create new parameter on meta device meta_param = torch.nn.Parameter( - torch.empty(param.shape, dtype=param.dtype, device="meta"), requires_grad=param.requires_grad + torch.empty(param.shape, dtype=param.dtype, device="meta"), + requires_grad=param.requires_grad, ) setattr(module, param_name, meta_param) # model = model.to(torch.device("meta")) @@ -907,7 +952,9 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module: local_buffer_name = fqn parent_module = model - parent_module.register_buffer(local_buffer_name, buffer_tensor, persistent=False) + parent_module.register_buffer( + local_buffer_name, buffer_tensor, persistent=False + ) # We need to tie the weights again, as call to `load_full_state_dict` breaks the tie # Needs to be called both here and above @@ -919,7 +966,9 @@ def fsdp2_prepare_model(accelerator, model: 
torch.nn.Module) -> torch.nn.Module: # There is no `dtype` attribution for nn.Module # Set it to None if it doesn't exist and do the upcast always model_dtype = getattr(model, "dtype", None) - if accelerator.mixed_precision != "no" and (model_dtype is None or model_dtype != torch.float32): + if accelerator.mixed_precision != "no" and ( + model_dtype is None or model_dtype != torch.float32 + ): # We upcast the model according to `deepspeed`'s implementation # More info about this can be found in `accelerator.py:prepare_model`s FSDP1 section model = model.to(torch.float32) @@ -928,4 +977,4 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module: warnings.warn( "FSDP upcast of low precision parameters to fp32 (since mixed_precision != 'no') may affect the precision of model checkpoints." ) - return model \ No newline at end of file + return model From abee5a8f29fef4a5235224a11626dd04543fa36e Mon Sep 17 00:00:00 2001 From: Mehant Kammakomati Date: Wed, 15 Oct 2025 02:19:07 +0530 Subject: [PATCH 3/4] fsdp2 patches Signed-off-by: Mehant Kammakomati --- .../utils/checkpoint_utils.py | 62 +++++++++---------- 1 file changed, 31 insertions(+), 31 deletions(-) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py index 7ea7d393..86bceb13 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py @@ -253,6 +253,7 @@ def patch_huggingface_clip_grad_norm_fsdp2(accelerator): def patch_huggingface_fsdp2_load_full_state_dict(): # Third Party + # pylint: disable=import-outside-toplevel from fms_acceleration.model_patcher import patch_target_module patch_target_module( @@ -722,16 +723,19 @@ def clip_grad_norm_(self, parameters, max_norm, norm_type=2): # code taken from HF accelerate and modified def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dict): """ - Loads the full state dict (could be only on rank 0) into the sharded model. This is done by broadcasting the - parameters from rank 0 to all other ranks. This function modifies the model in-place. + Loads the full state dict (could be only on rank 0) into the sharded model. + This is done by broadcasting the parameters from rank 0 to all other ranks. + This function modifies the model in-place. 
Args: accelerator (`Accelerator`): The accelerator instance model (`torch.nn.Module`): - The model to load the state dict into, expected to be on meta device or a VRAM spike can occur + The model to load the state dict into, expected to be on meta device + or a VRAM spike can occur full_sd (`dict`): The full state dict to load, can only be on rank 0 """ # Third Party + # pylint: disable=import-outside-toplevel from torch.distributed.tensor import distribute_tensor import torch.distributed as dist @@ -767,9 +771,9 @@ def _cast_and_contiguous(tensor, to_contiguous, dtype): tensor = tensor.contiguous() return tensor - # ignored_params = get_parameters_from_modules(accelerator.state.fsdp_plugin.ignored_modules, model, accelerator.device) ignored_params = { p.detach() + # pylint: disable=undefined-variable for p in get_parameters_from_modules( accelerator.state.fsdp_plugin.ignored_modules, model, accelerator.device ) @@ -786,11 +790,6 @@ def _cast_and_contiguous(tensor, to_contiguous, dtype): device_mesh = sharded_param.device_mesh full_param = full_param.detach().to(device_mesh.device_type) dist.broadcast(full_param, src=0, group=dist.group.WORLD) - # if device_mesh.ndim > 1: - # for mesh_dim_name in device_mesh.mesh_dim_names: - # dist.broadcast(full_param, src=0, group=device_mesh.get_group(mesh_dim=mesh_dim_name)) - # else: - # dist.broadcast(full_param, src=0, group=device_mesh.get_group()) sharded_tensor = distribute_tensor( full_param, device_mesh, sharded_param.placements ) @@ -818,11 +817,6 @@ def _cast_and_contiguous(tensor, to_contiguous, dtype): dtype=sharded_param.dtype, ) dist.broadcast(full_tensor, src=0, group=dist.group.WORLD) - # if device_mesh.ndim > 1: - # for mesh_dim_name in device_mesh.mesh_dim_names: - # dist.broadcast(full_tensor, src=0, group=device_mesh.get_group(mesh_dim=mesh_dim_name)) - # else: - # dist.broadcast(full_tensor, src=0, group=device_mesh.get_group()) sharded_tensor = distribute_tensor( full_tensor, device_mesh, sharded_param.placements ) @@ -843,7 +837,8 @@ def _cast_and_contiguous(tensor, to_contiguous, dtype): # code taken from HF accelerate and modified def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module: - """Prepares the model for FSDP2 in-place. Also returns the model to avoid misuse of the original model. + """Prepares the model for FSDP2 in-place. Also returns the model to avoid + misuse of the original model. 
Args: accelerator (`Accelerator`): The accelerator instance @@ -853,9 +848,11 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module: `torch.nn.Module`: Prepared model """ # Third Party + # pylint: disable=import-outside-toplevel from torch.distributed.fsdp import FSDPModule, MixedPrecisionPolicy, fully_shard is_type_fsdp = isinstance(model, FSDPModule) or ( + # pylint: disable=undefined-variable is_compiled_module(model) and isinstance(model._orig_mod, FSDPModule) ) if is_type_fsdp: @@ -878,32 +875,28 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module: if mesh is not None else None ), + # pylint: disable=undefined-variable "ignored_params": get_parameters_from_modules( fsdp2_plugin.ignored_modules, model, accelerator.device ), } model_has_params4bit = False - for name, param in model.named_parameters(): - # this is a temporary fix whereby loading models with bnb params cannot be moved from - # GPU to a meta device due with FSDP2 because torch operations don't return the original class type - # bypassing the move to meta will still cause the VRAM spike, but at least it still will load + for _, param in model.named_parameters(): + # this is a temporary fix whereby loading models with bnb params + # cannot be moved from GPU to a meta device due with FSDP2 because + # torch operations don't return the original class type bypassing the + # move to meta will still cause the VRAM spike, but at least it still will load if param.__class__.__name__ == "Params4bit": model_has_params4bit = True break if fsdp2_plugin.cpu_ram_efficient_loading and not model_has_params4bit: - # Context: `fully_shard` moves the model to GPU if it was on CPU, however it can also be on `meta` and then it stays there even after `fully_shard` - # For this reason, we need to move the model to `meta` device, as then sharding happens on `meta` device - # If we kept the model on CPU (`cpu_ram_efficient_loading` has model be on CPU on all ranks, though non-main ranks only have `torch.empty`), `fully_shard` would move it to GPU - # Afterwards, when we call `fsdp2_load_full_state_dict`, us creating the state_dict would result into briefly having two copies of model state_dict on the GPU -> VRAM spike - - # We need to keep the original non-persistent buffers, as those MAY not be in the state_dict, resulting in them staying on meta device - # Also, these buffers aren't getting sharded by default - # We get the FQNs of all non-persistent buffers, to re-register them after + # pylint: disable=undefined-variable non_persistent_buffer_fqns = get_non_persistent_buffers( model, recurse=True, fqns=True ) + # pylint: disable=undefined-variable original_non_persistent_buffers = copy.deepcopy( {k: v for k, v in model.named_buffers() if k in non_persistent_buffer_fqns} ) @@ -920,14 +913,17 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module: ) setattr(module, param_name, meta_param) # model = model.to(torch.device("meta")) - # We need to re-tie the weights, not exactly sure why, but if we don't do this, reference to `lm_head/embed_tokens` stay hanging -> more VRAM usage + # We need to re-tie the weights, not exactly sure why, but if we don't do this, + # reference to `lm_head/embed_tokens` stay hanging -> more VRAM usage # We assume `transformers` models have a `tie_weights` method if they support it if hasattr(model, "tie_weights"): model.tie_weights() + # pylint: disable=undefined-variable auto_wrap_policy_func = 
fsdp2_prepare_auto_wrap_policy(fsdp2_plugin, model) if auto_wrap_policy_func is not None: # We skip the model itself, as that one is always wrapped + # pylint: disable=undefined-variable for module in get_module_children_bottom_up(model)[:-1]: if auto_wrap_policy_func(module) and not isinstance(module, FSDPModule): fully_shard(module, **fsdp2_kwargs) @@ -937,7 +933,8 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module: if fsdp2_plugin.cpu_ram_efficient_loading: # If `cpu_ram_efficient_loading` is enabled, only rank 0 loads the weights - # Other ranks have an empty model on `meta` device, so we need to distribute the weights properly + # Other ranks have an empty model on `meta` device, so we need to distribute + # the weights properly fsdp2_load_full_state_dict(accelerator, model, original_sd) if fsdp2_plugin.cpu_ram_efficient_loading and not model_has_params4bit: @@ -970,11 +967,14 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module: model_dtype is None or model_dtype != torch.float32 ): # We upcast the model according to `deepspeed`'s implementation - # More info about this can be found in `accelerator.py:prepare_model`s FSDP1 section + # More info about this can be found in `accelerator.py:prepare_model`s + # FSDP1 section model = model.to(torch.float32) if accelerator.is_main_process: # TODO(siro1): Add a warning for each parameter that was upcasted + # pylint: disable=undefined-variable warnings.warn( - "FSDP upcast of low precision parameters to fp32 (since mixed_precision != 'no') may affect the precision of model checkpoints." + "FSDP upcast of low precision parameters to fp32 (since mixed_precision != 'no')" + "may affect the precision of model checkpoints." ) return model From f3a026082b9bc3154f67fee4faaaf81be0d827ca Mon Sep 17 00:00:00 2001 From: Mehant Kammakomati Date: Wed, 15 Oct 2025 11:20:32 +0530 Subject: [PATCH 4/4] fsdp2 patches Signed-off-by: Mehant Kammakomati --- .../fms_acceleration_moe/utils/checkpoint_utils.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py index 86bceb13..3269e1f2 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py @@ -723,14 +723,14 @@ def clip_grad_norm_(self, parameters, max_norm, norm_type=2): # code taken from HF accelerate and modified def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dict): """ - Loads the full state dict (could be only on rank 0) into the sharded model. - This is done by broadcasting the parameters from rank 0 to all other ranks. + Loads the full state dict (could be only on rank 0) into the sharded model. + This is done by broadcasting the parameters from rank 0 to all other ranks. This function modifies the model in-place. 
Args: accelerator (`Accelerator`): The accelerator instance model (`torch.nn.Module`): - The model to load the state dict into, expected to be on meta device + The model to load the state dict into, expected to be on meta device or a VRAM spike can occur full_sd (`dict`): The full state dict to load, can only be on rank 0 """ @@ -837,7 +837,7 @@ def _cast_and_contiguous(tensor, to_contiguous, dtype): # code taken from HF accelerate and modified def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module: - """Prepares the model for FSDP2 in-place. Also returns the model to avoid + """Prepares the model for FSDP2 in-place. Also returns the model to avoid misuse of the original model. Args: @@ -853,7 +853,8 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module: is_type_fsdp = isinstance(model, FSDPModule) or ( # pylint: disable=undefined-variable - is_compiled_module(model) and isinstance(model._orig_mod, FSDPModule) + is_compiled_module(model) + and isinstance(model._orig_mod, FSDPModule) ) if is_type_fsdp: return model