
Commit be45088

feat: support moe hf chkpt
Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>
1 parent 791bdd9 commit be45088

File tree

8 files changed (+157 additions, -65 deletions)


plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py

Lines changed: 85 additions & 2 deletions
@@ -14,18 +14,30 @@
 
 # Standard
 from typing import Dict, Tuple
+import asyncio
+import os
 
 # Third Party
+from accelerate import Accelerator
 from fms_acceleration import AccelerationPlugin
 from peft import LoraConfig
-from transformers import TrainingArguments
+from transformers import (
+    Trainer,
+    TrainerCallback,
+    TrainerControl,
+    TrainerState,
+    TrainingArguments,
+)
+from transformers.trainer import TRAINING_ARGS_NAME
+from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
 import torch
 
 # Local
 from .utils import (
     patch_huggingface_save_and_load_for_dtensors,
     patch_torch_optim_foreach_to_not_apply_to_dtensors,
     prepare_scattermoe,
+    recover_safetensors_from_dcp,
 )
 
 
@@ -90,10 +102,81 @@ def augmentation(
         return model, modifiable_args
 
     def get_callbacks_and_ready_for_train(
-        self, model: torch.nn.Module = None, accelerator=None
+        self,
+        model: torch.nn.Module = None,
+        accelerator: Accelerator = None,
+        trainer: Trainer = None,
+        pretrained_module_name_or_path: str = None,
     ):
 
         callbacks = []
+
+        class ConvertAndSaveHFCheckpointAtEverySave(TrainerCallback):
+            def __init__(self, pretrained_model_name_or_path: str, trainer: Trainer):
+                self.pretrained_model_name_or_path = pretrained_model_name_or_path
+                self.trainer = trainer
+
+            def on_save(
+                self,
+                args: TrainingArguments,
+                state: TrainerState,
+                control: TrainerControl,
+                **kwargs,
+            ):
+                """
+                Save all HF files and convert the dcp checkpoint to safetensors at every save operation.
+                """
+
+                async def checkpoint():
+                    checkpoint_dir = os.path.join(
+                        args.output_dir,
+                        f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
+                    )
+                    hf_converted_output_dir = os.path.join(
+                        checkpoint_dir, "hf_converted_checkpoint"
+                    )
+                    if os.path.exists(hf_converted_output_dir):
+                        # the folder may already exist, e.g. when the final
+                        # save at the end of training re-saves the same step,
+                        # so return early in that case
+                        return
+                    os.mkdir(hf_converted_output_dir)
+                    try:
+                        recover_safetensors_from_dcp(
+                            checkpoint_dir,
+                            self.pretrained_model_name_or_path,
+                            hf_converted_output_dir,
+                        )
+                        # save tokenizer
+                        if self.trainer.processing_class:
+                            self.trainer.processing_class.save_pretrained(
+                                hf_converted_output_dir
+                            )
+                        # save training args
+                        torch.save(
+                            args,
+                            os.path.join(
+                                hf_converted_output_dir, TRAINING_ARGS_NAME
+                            ),
+                        )
+                        # save model config files
+                        self.trainer.model.config.save_pretrained(
+                            hf_converted_output_dir
+                        )
+
+                    except Exception as e:
+                        raise ValueError(
+                            f"Failed to convert the checkpoint {checkpoint_dir} to a HF compatible checkpoint"
+                        ) from e
+
+                if state.is_world_process_zero:
+                    asyncio.run(checkpoint())
+
+        callbacks.append(
+            ConvertAndSaveHFCheckpointAtEverySave(
+                pretrained_model_name_or_path=pretrained_module_name_or_path,
+                trainer=trainer,
+            )
+        )
         if (
             accelerator is not None
             and getattr(accelerator.state, "fsdp_plugin", None) is not None
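
A note on the pattern above: the new callback hooks the transformers on_save event, resolves the checkpoint folder the Trainer just wrote from PREFIX_CHECKPOINT_DIR and state.global_step, and does its work on the main process only. Below is that skeleton in isolation, as a minimal sketch with a hypothetical class name and no fms_acceleration dependencies; it is an illustration, not part of this commit.

    import os

    from transformers import (
        TrainerCallback,
        TrainerControl,
        TrainerState,
        TrainingArguments,
    )
    from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR


    class SaveTimeHook(TrainerCallback):
        """Illustrative skeleton: run extra post-processing after every checkpoint save."""

        def on_save(
            self,
            args: TrainingArguments,
            state: TrainerState,
            control: TrainerControl,
            **kwargs,
        ):
            # folder the Trainer just wrote, e.g. <output_dir>/checkpoint-500
            checkpoint_dir = os.path.join(
                args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
            )
            # only the main process should touch the shared filesystem
            if state.is_world_process_zero:
                converted_dir = os.path.join(checkpoint_dir, "hf_converted_checkpoint")
                if not os.path.exists(converted_dir):
                    os.makedirs(converted_dir, exist_ok=True)
                    # a real hook (like the one added in this commit) would
                    # convert the dcp shards to safetensors into converted_dir here


    # typical wiring: trainer.add_callback(SaveTimeHook()) before trainer.train()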

plugins/accelerated-moe/src/fms_acceleration_moe/utils/__init__.py

Lines changed: 4 additions & 1 deletion
@@ -13,7 +13,10 @@
 # limitations under the License.
 
 # Local
-from .checkpoint_utils import patch_huggingface_save_and_load_for_dtensors
+from .checkpoint_utils import (
+    patch_huggingface_save_and_load_for_dtensors,
+    recover_safetensors_from_dcp,
+)
 from .scattermoe_prepare import prepare_scattermoe
 
 # this is a special patch function to disable foreach for

plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py

Lines changed: 54 additions & 51 deletions
@@ -457,75 +457,38 @@ def save_sharded_safetensors(
 # --------------------------- SCRIPT -------------------------
 
 
-# have it serve as a conversion script
-if __name__ == "__main__":
-    # Standard
-    import argparse
-
-    parser = argparse.ArgumentParser(
-        description=(
-            "Utility for converting ScatterMoE checkpoint back to the "
-            "original state dict format. "
-            "The ScatterMoE checkpoint was saved after the pretrained model "
-            "had been converted by a module swap, hence the state dict will "
-            "no longer resemble the original. This utility creates"
-        )
-    )
-
-    parser.add_argument(
-        "checkpoint_dir",
-        help="Path to the checkpoint.",
-    )
-
-    parser.add_argument(
-        "output_dir", help="Path to the location to write the converted checkpoint."
-    )
-
-    parser.add_argument(
-        "pretrained_model_name_or_path",
-        help=(
-            "In order to reconstruct the state dict, we require hints from "
-            "the original pretrained model checkpoint (from which this "
-            "checkpoint is obtained)."
-        ),
-        default=None,
-    )
-
-    args = parser.parse_args()
-
-    # search for an FSDP checkpoint. If it is an FSDP checkpoint, it must
-    # start with FSDP_MODEL_NAME
-    if args.checkpoint_dir.startswith(FSDP_MODEL_NAME):
-        checkpoint_dir = args.checkpoint_dir
+def recover_safetensors_from_dcp(
+    checkpoint_dir, pretrained_model_name_or_path, output_dir
+):
+    if checkpoint_dir.startswith(FSDP_MODEL_NAME):
         loader = get_state_dict_from_dcp_checkpoint
     else:
-        checkpoint_dir = [
+        fsdp_checkpoint_dirs = [
             x
-            for x in os.listdir(args.checkpoint_dir)
-            if os.path.isdir(os.path.join(args.checkpoint_dir, x))
+            for x in os.listdir(checkpoint_dir)
+            if os.path.isdir(os.path.join(checkpoint_dir, x))
             and x.startswith(FSDP_MODEL_NAME)
         ]
-        if len(checkpoint_dir) == 1:
-            checkpoint_dir = os.path.join(args.checkpoint_dir, checkpoint_dir[0])
+        if len(fsdp_checkpoint_dirs) == 1:
+            checkpoint_dir = os.path.join(checkpoint_dir, fsdp_checkpoint_dirs[0])
             loader = get_state_dict_from_dcp_checkpoint
-        elif len(checkpoint_dir) > 1:
+        elif len(fsdp_checkpoint_dirs) > 1:
             raise ValueError(
-                f"Found > 1 dirs in dcp checkpoint dir {args.checkpoint_dir} "
+                f"Found > 1 dirs in dcp checkpoint dir {checkpoint_dir} "
                 f"that starts with {FSDP_MODEL_NAME}. Please specify the exact dir."
             )
         else:
             # then take it as a safetensors checkpoint
             # - do not support .bin checkpoints
-            checkpoint_dir = args.checkpoint_dir
             loader = get_state_dict_from_safe_checkpoint
 
     # - pretrained model name
-    _name_or_path = args.pretrained_model_name_or_path
+    _name_or_path = pretrained_model_name_or_path
 
     # assume output directory exists, we do not create it
     # - copy the config file if exists
     config_file = os.path.join(checkpoint_dir, CONFIG_NAME)
-    target_config_file = os.path.join(args.output_dir, CONFIG_NAME)
+    target_config_file = os.path.join(output_dir, CONFIG_NAME)
     if os.path.exists(config_file):
         shutil.copyfile(config_file, target_config_file)
 
@@ -544,6 +507,46 @@ def save_sharded_safetensors(
     # save it as a safetensors file
     save_sharded_safetensors(
         {k: v.contiguous() for k, v in state_dict.items()},
-        args.output_dir,
+        output_dir,
         metadata={"format": "pt"},
     )
+
+
+# have it serve as a conversion script
+if __name__ == "__main__":
+    # Standard
+    import argparse
+
+    parser = argparse.ArgumentParser(
+        description=(
+            "Utility for converting ScatterMoE checkpoint back to the "
+            "original state dict format. "
+            "The ScatterMoE checkpoint was saved after the pretrained model "
+            "had been converted by a module swap, hence the state dict will "
+            "no longer resemble the original. This utility creates"
+        )
+    )
+
+    parser.add_argument(
+        "checkpoint_dir",
+        help="Path to the checkpoint.",
+    )
+
+    parser.add_argument(
+        "output_dir", help="Path to the location to write the converted checkpoint."
+    )
+
+    parser.add_argument(
+        "pretrained_model_name_or_path",
+        help=(
+            "In order to reconstruct the state dict, we require hints from "
+            "the original pretrained model checkpoint (from which this "
+            "checkpoint is obtained)."
+        ),
+        default=None,
+    )
+
+    args = parser.parse_args()
+    recover_safetensors_from_dcp(
+        args.checkpoint_dir, args.pretrained_model_name_or_path, args.output_dir
+    )
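
For reference, a usage sketch of the extracted function (not part of the commit; the paths and model id below are placeholders). Per the code above, the function assumes the output directory already exists, copies the config from the checkpoint if present, and writes sharded safetensors.

    import os

    # re-exported from fms_acceleration_moe.utils in this commit
    from fms_acceleration_moe.utils import recover_safetensors_from_dcp

    checkpoint_dir = "output/checkpoint-500"  # placeholder: a Trainer checkpoint folder
    output_dir = os.path.join(checkpoint_dir, "hf_converted_checkpoint")
    os.makedirs(output_dir, exist_ok=True)    # the function does not create it

    recover_safetensors_from_dcp(
        checkpoint_dir,
        "mistralai/Mixtral-8x7B-v0.1",  # placeholder; hints for rebuilding the state dict
        output_dir,
    )

    # roughly equivalent CLI, using the argparse block above
    # (positional order: checkpoint_dir, output_dir, pretrained_model_name_or_path):
    #   python -m fms_acceleration_moe.utils.checkpoint_utils \
    #       output/checkpoint-500 output/checkpoint-500/hf_converted_checkpoint \
    #       mistralai/Mixtral-8x7B-v0.1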

plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py

Lines changed: 3 additions & 2 deletions
@@ -27,9 +27,10 @@
 from fms_acceleration.model_patcher import patch_target_module
 from peft import LoraConfig, prepare_model_for_kbit_training
 from peft.tuners.lora.model import LoraModel
-from transformers import AutoModelForCausalLM, TrainingArguments
+from transformers import AutoModelForCausalLM, TrainingArguments, Trainer
 from transformers.modeling_utils import is_fsdp_enabled
 from transformers.utils.import_utils import _is_package_available
+from accelerate import Accelerator
 import torch
 import torch.distributed
 
@@ -355,7 +356,7 @@ def augmentation(
         return model, modifiable_args
 
     def get_callbacks_and_ready_for_train(
-        self, model: torch.nn.Module = None, accelerator=None
+        self, model: torch.nn.Module = None, accelerator: Accelerator = None, trainer: Trainer = None, pretrained_module_name_or_path: str = None
     ):
         callbacks = []
         if (

plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_bnb.py

Lines changed: 3 additions & 2 deletions
@@ -24,8 +24,9 @@
 # Third Party
 from fms_acceleration import AccelerationPlugin
 from peft import LoraConfig, get_peft_model
-from transformers import AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments
+from transformers import AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments, Trainer
 from transformers.utils.import_utils import _is_package_available
+from accelerate import Accelerator
 import torch
 
 # Local
@@ -218,7 +219,7 @@ def augmentation(
         return model, modifiable_args
 
     def get_callbacks_and_ready_for_train(
-        self, model: torch.nn.Module = None, accelerator=None
+        self, model: torch.nn.Module = None, accelerator: Accelerator = None, trainer: Trainer = None, pretrained_module_name_or_path: str = None
     ):
         callbacks = []
         if (

plugins/framework/src/fms_acceleration/framework.py

Lines changed: 3 additions & 3 deletions
@@ -18,7 +18,7 @@
 
 # Third Party
 from accelerate import Accelerator
-from transformers import PreTrainedModel, TrainingArguments
+from transformers import PreTrainedModel, TrainingArguments, Trainer
 from transformers.utils import logging
 from transformers.utils.import_utils import _is_package_available
 import torch
@@ -218,7 +218,7 @@ def requires_augmentation(self):
         return any(x.requires_augmentation for _, x in self.active_plugins)
 
     def get_callbacks_and_ready_for_train(
-        self, model: torch.nn.Module = None, accelerator: Accelerator = None
+        self, model: torch.nn.Module = None, accelerator: Accelerator = None, trainer: Trainer = None, pretrained_module_name_or_path: str = None
     ):
 
         # Local
@@ -257,5 +257,5 @@ def get_callbacks_and_ready_for_train(
 
         cbks = []
         for _, plugin in self.active_plugins:
-            cbks.extend(plugin.get_callbacks_and_ready_for_train(model, accelerator))
+            cbks.extend(plugin.get_callbacks_and_ready_for_train(model, accelerator, trainer, pretrained_module_name_or_path))
         return cbks
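
A caller-side sketch of how the two new arguments are expected to flow through this framework method (an assumption based on the signature change above; the helper name is hypothetical, and trainer.accelerator is the standard transformers Trainer attribute):

    from typing import Optional

    from transformers import Trainer


    def attach_plugin_callbacks(framework, trainer: Trainer, model_name: Optional[str] = None):
        """Hypothetical helper: forward the new args and register the returned callbacks."""
        callbacks = framework.get_callbacks_and_ready_for_train(
            model=trainer.model,
            accelerator=trainer.accelerator,            # Accelerator owned by the Trainer
            trainer=trainer,                            # new argument in this commit
            pretrained_module_name_or_path=model_name,  # new argument in this commit
        )
        for cb in callbacks:
            trainer.add_callback(cb)
        return callbacks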

plugins/framework/src/fms_acceleration/framework_plugin.py

Lines changed: 2 additions & 2 deletions
@@ -21,7 +21,7 @@
 # Third Party
 from accelerate import Accelerator
 from peft import LoraConfig
-from transformers import TrainingArguments
+from transformers import TrainingArguments, Trainer
 import torch
 
 
@@ -186,7 +186,7 @@ def augmentation(
        raise NotImplementedError
 
     def get_callbacks_and_ready_for_train(
-        self, model: torch.nn.Module = None, accelerator: Accelerator = None
+        self, model: torch.nn.Module = None, accelerator: Accelerator = None, trainer: Trainer = None, pretrained_module_name_or_path: str = None
     ):
         return []

plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_kernels.py

Lines changed: 3 additions & 2 deletions
@@ -19,7 +19,8 @@
 from fms_acceleration import AccelerationPlugin, AccelerationPluginConfigError
 from peft import LoraConfig
 from peft.tuners.lora.layer import LoraLayer
-from transformers import PretrainedConfig, TrainingArguments
+from transformers import PretrainedConfig, TrainingArguments, Trainer
+from accelerate import Accelerator
 import torch
 
 # Local
@@ -184,7 +185,7 @@ def augmentation(
         return model, modifiable_args
 
     def get_callbacks_and_ready_for_train(
-        self, model: torch.nn.Module = None, accelerator=None
+        self, model: torch.nn.Module = None, accelerator: Accelerator = None, trainer: Trainer = None, pretrained_module_name_or_path: str = None
     ):
         # This callback applies only for qpeft
         # should not install this for full FT and standard peft
