Commit 2990230
fix: state dict patch and _bitsandbytes_available (#161)
* fix: ignored params for ep 8
* fix: patch sd options
* debug
* fix
* fix: fsdp2

Signed-off-by: Mehant Kammakomati <[email protected]>
1 parent 73b4f6b commit 2990230

File tree

4 files changed: +55 -7 lines changed

  plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py
  plugins/accelerated-moe/src/fms_acceleration_moe/utils/__init__.py
  plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py
  plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/utils.py
plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py

Lines changed: 7 additions & 0 deletions

@@ -28,6 +28,7 @@
     patch_huggingface_clip_grad_norm_fsdp2,
     patch_huggingface_fsdp2_load_full_state_dict,
     patch_huggingface_save_and_load_for_dtensors,
+    patch_prepare_sd_options,
     patch_torch_optim_foreach_to_not_apply_to_dtensors,
     prepare_scattermoe,
 )
@@ -118,6 +119,12 @@ def get_callbacks_and_ready_for_train(
             accelerator is not None
             and getattr(accelerator.state, "fsdp_plugin", None) is not None
         ):
+            if (
+                hasattr(accelerator.state.fsdp_plugin, "fsdp_version")
+                and accelerator.state.fsdp_plugin.fsdp_version == 2
+            ):
+                # when FSDPv2 is used
+                patch_prepare_sd_options()

             if not self._disable_distributed:
                 # - use an internal function call to get the no split
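For context, the sketch below (not part of this commit) mirrors how the added FSDPv2 guard decides whether to apply the patch. `should_patch_sd_options` is a hypothetical helper, and the SimpleNamespace objects stand in for accelerate's Accelerator state and FSDP plugin.

# A minimal sketch of the fsdp_version check used in the guard above.
from types import SimpleNamespace


def should_patch_sd_options(accelerator) -> bool:
    """Patch only when an FSDP plugin is present and reports fsdp_version == 2."""
    plugin = getattr(getattr(accelerator, "state", None), "fsdp_plugin", None)
    return plugin is not None and getattr(plugin, "fsdp_version", 1) == 2


# An FSDPv1-style plugin (no fsdp_version attribute) is left untouched.
v1 = SimpleNamespace(state=SimpleNamespace(fsdp_plugin=SimpleNamespace()))
v2 = SimpleNamespace(state=SimpleNamespace(fsdp_plugin=SimpleNamespace(fsdp_version=2)))
assert not should_patch_sd_options(v1)
assert should_patch_sd_options(v2)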

plugins/accelerated-moe/src/fms_acceleration_moe/utils/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -17,6 +17,7 @@
     patch_huggingface_clip_grad_norm_fsdp2,
     patch_huggingface_fsdp2_load_full_state_dict,
     patch_huggingface_save_and_load_for_dtensors,
+    patch_prepare_sd_options,
     recover_safetensors_from_dcp,
 )
 from .scattermoe_prepare import prepare_scattermoe

plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py

Lines changed: 41 additions & 6 deletions

@@ -107,14 +107,15 @@ def save_fsdp_model(
 def save_fsdp_optimizer(
     fsdp_plugin, accelerator, optimizer, model, output_dir, optimizer_index=0
 ):
-
     if fsdp_plugin.state_dict_type != StateDictType.SHARDED_STATE_DICT:
         raise NotImplementedError(
             "Checkpointing for megablocks only enabled for sharded state dict."
         )
-
+    sd_options = _prepare_sd_options(fsdp_plugin)
     # get the state dicts for model and optimize
-    (model_state_dict, optimizer_state_dict) = get_state_dict(model, optimizer)
+    (model_state_dict, optimizer_state_dict) = get_state_dict(
+        model, optimizer, options=sd_options
+    )

     # filter out lora state dict
     # TODO: Once expert layers are supported for LoRA tuning
@@ -157,6 +158,28 @@ def save_fsdp_optimizer(
     logger.info(f"Optimizer state saved in {ckpt_opt}")


+def _prepare_sd_options(fsdp_plugin):
+    sd_options = None
+
+    # we use this only for FSDP2, as it requires torch >= 2.6.0 and this api requires torch >= 2.2.0
+    if fsdp_plugin.fsdp_version == 2:
+        # pylint: disable=import-outside-toplevel
+        # Third Party
+        from torch.distributed.checkpoint.state_dict import StateDictOptions
+
+        sd_options = StateDictOptions(
+            full_state_dict=fsdp_plugin.state_dict_type
+            == StateDictType.FULL_STATE_DICT,
+            cpu_offload=getattr(fsdp_plugin.state_dict_config, "offload_to_cpu", False),
+            broadcast_from_rank0=getattr(
+                fsdp_plugin.state_dict_config, "rank0_only", False
+            ),
+            flatten_optimizer_state_dict=True,
+        )
+
+    return sd_options
+
+
 # rewrite of func from accelerate.utils.fsdp_utils.py
 # - empty function, main logic in load_fsdp_optimizer (see below).
 def load_fsdp_model(
@@ -178,15 +201,16 @@ def load_fsdp_optimizer(
     optimizer_index=0,
     adapter_only=False,
 ):
-
     accelerator.wait_for_everyone()
     if fsdp_plugin.state_dict_type != StateDictType.SHARDED_STATE_DICT:
         raise NotImplementedError(
             "Checkpointing for megablocks only enabled for sharded state dict."
         )
-
+    sd_options = _prepare_sd_options(fsdp_plugin)
     # - get the state dicts
-    model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer)
+    model_state_dict, optimizer_state_dict = get_state_dict(
+        model, optimizer, options=sd_options
+    )

     # - load the model state dict
     ckpt_model = os.path.join(input_dir, f"{FSDP_MODEL_NAME}_{MODEL_INDEX}")
@@ -210,6 +234,7 @@ def load_fsdp_optimizer(
         optimizer,
         model_state_dict=model_state_dict,
         optim_state_dict=optimizer_state_dict,
+        options=sd_options,
     )

     # FIXME:
@@ -246,6 +271,16 @@ def patch_huggingface_save_and_load_for_dtensors():
     patch_target_module("transformers.trainer.load_fsdp_optimizer", load_fsdp_optimizer)


+def patch_prepare_sd_options():
+    # Third Party
+    # pylint: disable=import-outside-toplevel
+    from fms_acceleration.model_patcher import patch_target_module
+
+    patch_target_module(
+        "accelerate.utils.fsdp_utils._prepare_sd_options", _prepare_sd_options
+    )
+
+
 # function to monkey patch accelerator clip grad_norm
 def patch_huggingface_clip_grad_norm_fsdp2(accelerator):
     accelerator.clip_grad_norm_ = types.MethodType(clip_grad_norm_, accelerator)
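As a rough illustration of what `_prepare_sd_options` produces for FSDPv2, the sketch below (not part of this commit) builds the same StateDictOptions from a hypothetical `FakePlugin` standing in for accelerate's FSDP plugin. It assumes a torch version recent enough to expose broadcast_from_rank0 and flatten_optimizer_state_dict on StateDictOptions.

# A sketch of the options object that the patched save/load paths pass to
# torch.distributed.checkpoint.state_dict helpers.
from dataclasses import dataclass, field

from torch.distributed.checkpoint.state_dict import StateDictOptions
from torch.distributed.fsdp import ShardedStateDictConfig, StateDictType


@dataclass
class FakePlugin:
    # hypothetical stand-in for accelerate's FSDP plugin fields used above
    fsdp_version: int = 2
    state_dict_type: StateDictType = StateDictType.SHARDED_STATE_DICT
    state_dict_config: ShardedStateDictConfig = field(
        default_factory=lambda: ShardedStateDictConfig(offload_to_cpu=True)
    )


plugin = FakePlugin()
options = StateDictOptions(
    full_state_dict=plugin.state_dict_type == StateDictType.FULL_STATE_DICT,
    cpu_offload=getattr(plugin.state_dict_config, "offload_to_cpu", False),
    broadcast_from_rank0=getattr(plugin.state_dict_config, "rank0_only", False),
    flatten_optimizer_state_dict=True,
)

# In the patched save/load paths these options are then forwarded, roughly:
#   model_sd, optim_sd = get_state_dict(model, optimizer, options=options)
#   set_state_dict(model, optimizer, model_state_dict=model_sd,
#                  optim_state_dict=optim_sd, options=options)
assert options.cpu_offload is True and options.full_state_dict is False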

plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/utils.py

Lines changed: 6 additions & 1 deletion

@@ -31,7 +31,12 @@ def calculate_settings(n):
 pass

 # import guard added by [email protected]
-from transformers.utils.import_utils import _bitsandbytes_available
+try:
+    from transformers.utils.import_utils import _bitsandbytes_available
+except ImportError:
+    from transformers.utils.import_utils import is_bitsandbytes_available
+    _bitsandbytes_available = is_bitsandbytes_available()
+
 if _bitsandbytes_available:
     import bitsandbytes as bnb
     get_ptr = bnb.functional.get_ptr
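The same compatibility guard can be wrapped in a helper, as in the sketch below (not part of this commit); `resolve_bitsandbytes_flag` is a hypothetical name, and the example assumes transformers is installed.

# Resolve the bitsandbytes availability flag across transformers versions.
def resolve_bitsandbytes_flag() -> bool:
    """Prefer the legacy private flag; fall back to the public helper on newer
    transformers releases where the private name is gone."""
    try:
        # older transformers exposed a private module-level boolean
        from transformers.utils.import_utils import _bitsandbytes_available

        return bool(_bitsandbytes_available)
    except ImportError:
        # newer transformers only expose the public accessor
        from transformers.utils.import_utils import is_bitsandbytes_available

        return is_bitsandbytes_available()


print("bitsandbytes available:", resolve_bitsandbytes_flag())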
