@@ -28,6 +28,7 @@
patch_huggingface_clip_grad_norm_fsdp2,
patch_huggingface_fsdp2_load_full_state_dict,
patch_huggingface_save_and_load_for_dtensors,
patch_prepare_sd_options,
patch_torch_optim_foreach_to_not_apply_to_dtensors,
prepare_scattermoe,
)
@@ -118,6 +119,12 @@ def get_callbacks_and_ready_for_train(
accelerator is not None
and getattr(accelerator.state, "fsdp_plugin", None) is not None
):
if (
hasattr(accelerator.state.fsdp_plugin, "fsdp_version")
and accelerator.state.fsdp_plugin.fsdp_version == 2
):
# when FSDPv2 is used
patch_prepare_sd_options()

if not self._disable_distributed:
# - use an internal function call to get the no split
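
For context, the gate added above can be read in isolation as the sketch below. It assumes `accelerator` is an accelerate.Accelerator configured with an FSDP plugin; older plugins expose no `fsdp_version` attribute and are treated as FSDPv1, which is why the `hasattr` check is there.

    fsdp_plugin = getattr(accelerator.state, "fsdp_plugin", None)
    if fsdp_plugin is not None and getattr(fsdp_plugin, "fsdp_version", 1) == 2:
        # only FSDPv2 needs the patched state-dict options; plugins without
        # the attribute are treated as FSDPv1 and left untouched
        patch_prepare_sd_options()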
@@ -17,6 +17,7 @@
patch_huggingface_clip_grad_norm_fsdp2,
patch_huggingface_fsdp2_load_full_state_dict,
patch_huggingface_save_and_load_for_dtensors,
patch_prepare_sd_options,
recover_safetensors_from_dcp,
)
from .scattermoe_prepare import prepare_scattermoe
@@ -107,14 +107,15 @@ def save_fsdp_model(
def save_fsdp_optimizer(
fsdp_plugin, accelerator, optimizer, model, output_dir, optimizer_index=0
):

if fsdp_plugin.state_dict_type != StateDictType.SHARDED_STATE_DICT:
raise NotImplementedError(
"Checkpointing for megablocks only enabled for sharded state dict."
)

sd_options = _prepare_sd_options(fsdp_plugin)
# get the state dicts for model and optimizer
(model_state_dict, optimizer_state_dict) = get_state_dict(model, optimizer)
(model_state_dict, optimizer_state_dict) = get_state_dict(
model, optimizer, options=sd_options
)

# filter out lora state dict
# TODO: Once expert layers are supported for LoRA tuning
@@ -157,6 +158,28 @@ def save_fsdp_optimizer(
logger.info(f"Optimizer state saved in {ckpt_opt}")


def _prepare_sd_options(fsdp_plugin):
sd_options = None

# used only for FSDP2 (which itself needs torch >= 2.6.0); the StateDictOptions API needs torch >= 2.2.0
if fsdp_plugin.fsdp_version == 2:
# pylint: disable=import-outside-toplevel
# Third Party
from torch.distributed.checkpoint.state_dict import StateDictOptions

sd_options = StateDictOptions(
full_state_dict=fsdp_plugin.state_dict_type
== StateDictType.FULL_STATE_DICT,
cpu_offload=getattr(fsdp_plugin.state_dict_config, "offload_to_cpu", False),
broadcast_from_rank0=getattr(
fsdp_plugin.state_dict_config, "rank0_only", False
),
flatten_optimizer_state_dict=True,
)

return sd_options
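
For context, a toy single-process sketch of how the options built here are consumed by the torch distributed checkpoint API, assuming a recent torch (>= 2.2 for `get_state_dict`/`set_state_dict`, newer for `flatten_optimizer_state_dict`). The restore call mirrors the keyword signature visible in the `load_fsdp_optimizer` hunk further down, which is assumed to be torch's `set_state_dict`.

    import torch
    from torch.distributed.checkpoint.state_dict import (
        StateDictOptions,
        get_state_dict,
        set_state_dict,
    )

    model = torch.nn.Linear(8, 8)
    optimizer = torch.optim.AdamW(model.parameters())

    # roughly what _prepare_sd_options builds for a FULL_STATE_DICT configuration
    sd_options = StateDictOptions(
        full_state_dict=True,               # gather an unsharded state dict
        cpu_offload=False,                  # True would move gathered tensors to CPU
        flatten_optimizer_state_dict=True,  # flatten optimizer state into FQN-keyed entries
    )

    # save side: fetch model and optimizer state dicts in one call
    model_sd, optim_sd = get_state_dict(model, optimizer, options=sd_options)

    # load side: push the state dicts back with the same options
    set_state_dict(
        model,
        optimizer,
        model_state_dict=model_sd,
        optim_state_dict=optim_sd,
        options=sd_options,
    )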


# rewrite of func from accelerate.utils.fsdp_utils.py
# - empty function, main logic in load_fsdp_optimizer (see below).
def load_fsdp_model(
@@ -178,15 +201,16 @@ def load_fsdp_optimizer(
optimizer_index=0,
adapter_only=False,
):

accelerator.wait_for_everyone()
if fsdp_plugin.state_dict_type != StateDictType.SHARDED_STATE_DICT:
raise NotImplementedError(
"Checkpointing for megablocks only enabled for sharded state dict."
)

sd_options = _prepare_sd_options(fsdp_plugin)
# - get the state dicts
model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer)
model_state_dict, optimizer_state_dict = get_state_dict(
model, optimizer, options=sd_options
)

# - load the model state dict
ckpt_model = os.path.join(input_dir, f"{FSDP_MODEL_NAME}_{MODEL_INDEX}")
@@ -210,6 +234,7 @@ def load_fsdp_optimizer(
optimizer,
model_state_dict=model_state_dict,
optim_state_dict=optimizer_state_dict,
options=sd_options,
)

# FIXME:
@@ -246,6 +271,16 @@ def patch_huggingface_save_and_load_for_dtensors():
patch_target_module("transformers.trainer.load_fsdp_optimizer", load_fsdp_optimizer)


def patch_prepare_sd_options():
# Third Party
# pylint: disable=import-outside-toplevel
from fms_acceleration.model_patcher import patch_target_module

patch_target_module(
"accelerate.utils.fsdp_utils._prepare_sd_options", _prepare_sd_options
)
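
For readers who have not seen `patch_target_module`: conceptually it swaps the object at a dotted path for a replacement, roughly like the hand-rolled helper below. This is illustrative only; the real helper in fms_acceleration.model_patcher may do more, for example handle modules that imported the symbol by value.

    import importlib

    def _manual_patch_target(target: str, replacement) -> None:
        # split "pkg.module.attr" into the module path and the attribute name
        module_path, _, attr = target.rpartition(".")
        module = importlib.import_module(module_path)
        setattr(module, attr, replacement)

    # e.g. _manual_patch_target(
    #     "accelerate.utils.fsdp_utils._prepare_sd_options", _prepare_sd_options
    # )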


# function to monkey patch accelerator clip grad_norm
def patch_huggingface_clip_grad_norm_fsdp2(accelerator):
accelerator.clip_grad_norm_ = types.MethodType(clip_grad_norm_, accelerator)
@@ -31,7 +31,12 @@ def calculate_settings(n):
pass

# import guard added by [email protected]
from transformers.utils.import_utils import _bitsandbytes_available
try:
from transformers.utils.import_utils import _bitsandbytes_available
except ImportError:
from transformers.utils.import_utils import is_bitsandbytes_available
_bitsandbytes_available = is_bitsandbytes_available()

if _bitsandbytes_available:
import bitsandbytes as bnb
get_ptr = bnb.functional.get_ptr
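
The guard above can also be written as a small reusable helper, shown here only as a sketch: the try/except covers transformers versions where the private `_bitsandbytes_available` flag is no longer importable and the public `is_bitsandbytes_available()` check is used instead.

    def bitsandbytes_available() -> bool:
        # prefer the legacy private flag when it is importable,
        # else fall back to the public transformers API
        try:
            from transformers.utils.import_utils import _bitsandbytes_available
            return bool(_bitsandbytes_available)
        except ImportError:
            from transformers.utils.import_utils import is_bitsandbytes_available
            return is_bitsandbytes_available()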