Skip to content

Commit c3aa25c

Browse files
authored
feat: support moe hf chkpt (#486)
* feat: support moe hf chkpt Signed-off-by: Mehant Kammakomati <[email protected]> * feat: move callbacks to fms-hf-tuning Signed-off-by: Mehant Kammakomati <[email protected]> * feat: update docs Signed-off-by: Mehant Kammakomati <[email protected]> * feat: add test case for moe saving hf checkpoint Signed-off-by: Mehant Kammakomati <[email protected]> --------- Signed-off-by: Mehant Kammakomati <[email protected]>
1 parent 7c218d2 commit c3aa25c

File tree

6 files changed

+145
-1
lines changed

6 files changed

+145
-1
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -804,7 +804,7 @@ Notes:
804804
* Notes on Fast MoE
805805
- `--fast_moe` is an integer value that configures the amount of expert parallel sharding (ep_degree).
806806
- `world_size` must be divisible by the `ep_degree`
807-
- Running fast moe modifies the state dict of the model, and must be post-processed using [checkpoint utils](https://github.com/foundation-model-stack/fms-acceleration/blob/main/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py) to run inference (HF, vLLM, etc.).
807+
- Running fast moe modifies the state dict of the model, and must be post-processed which happens automatically and the converted checkpoint can be found at `hf_converted_checkpoint` folder within every saved checkpoint directory. Alternatively, we can perform similar option manually through [checkpoint utils](https://github.com/foundation-model-stack/fms-acceleration/blob/main/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py) script.
808808
- The typical usecase for this script is to run:
809809
```
810810
python -m fms_acceleration_moe.utils.checkpoint_utils \

tests/test_sft_trainer.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@
7373
# Local
7474
from tuning import sft_trainer
7575
from tuning.config import configs, peft_config
76+
from tuning.config.acceleration_configs.fast_moe import FastMoe, FastMoeConfig
7677
from tuning.config.tracker_configs import FileLoggingTrackerConfig
7778
from tuning.data.data_config import (
7879
DataConfig,
@@ -85,6 +86,7 @@
8586
DataHandlerType,
8687
add_tokenizer_eos_token,
8788
)
89+
from tuning.utils.import_utils import is_fms_accelerate_available
8890

8991
MODEL_ARGS = configs.ModelArguments(
9092
model_name_or_path=MODEL_NAME, use_flash_attn=False, torch_dtype="float32"
@@ -1336,6 +1338,36 @@ def test_run_e2e_with_hf_dataset_id(data_args):
13361338
_test_run_inference(checkpoint_path=_get_checkpoint_path(tempdir))
13371339

13381340

1341+
@pytest.mark.skipif(
1342+
not is_fms_accelerate_available(plugins="moe"),
1343+
reason="Only runs if fms-accelerate is installed along with accelerated-moe plugin",
1344+
)
1345+
@pytest.mark.parametrize(
1346+
"dataset_path",
1347+
[
1348+
TWITTER_COMPLAINTS_DATA_JSONL,
1349+
],
1350+
)
1351+
def test_run_moe_ft_and_inference(dataset_path):
1352+
"""Check if we can finetune a moe model and check if hf checkpoint is created"""
1353+
with tempfile.TemporaryDirectory() as tempdir:
1354+
data_args = copy.deepcopy(DATA_ARGS)
1355+
data_args.training_data_path = dataset_path
1356+
model_args = copy.deepcopy(MODEL_ARGS)
1357+
model_args.model_name_or_path = "Isotonic/TinyMixtral-4x248M-MoE"
1358+
train_args = copy.deepcopy(TRAIN_ARGS)
1359+
train_args.output_dir = tempdir
1360+
fast_moe_config = FastMoeConfig(fast_moe=FastMoe(ep_degree=1))
1361+
sft_trainer.train(
1362+
model_args, data_args, train_args, fast_moe_config=fast_moe_config
1363+
)
1364+
_test_run_inference(
1365+
checkpoint_path=os.path.join(
1366+
_get_checkpoint_path(tempdir), "hf_converted_checkpoint"
1367+
)
1368+
)
1369+
1370+
13391371
############################# Helper functions #############################
13401372
def _test_run_causallm_ft(training_args, model_args, data_args, tempdir):
13411373
train_args = copy.deepcopy(training_args)

tuning/config/acceleration_configs/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# Local
1616
from .acceleration_framework_config import AccelerationFrameworkConfig
1717
from .attention_and_distributed_packing import AttentionAndDistributedPackingConfig
18+
from .callbacks import get_additional_accel_framework_callbacks
1819
from .fast_moe import FastMoeConfig
1920
from .fused_ops_and_kernels import FusedOpsAndKernelsConfig
2021
from .quantized_lora_config import QuantizedLoraConfig
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# Local
2+
from .fast_moe import get_callbacks
3+
4+
5+
def get_additional_accel_framework_callbacks(active_plugins, **kwargs):
6+
callbacks = []
7+
for active_plugin in active_plugins:
8+
if "ScatterMoEAccelerationPlugin" == active_plugin[0]:
9+
callbacks.extend(get_callbacks(**kwargs))
10+
return callbacks

tuning/config/acceleration_configs/fast_moe.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,30 @@
1414

1515
# Standard
1616
from dataclasses import dataclass
17+
import os
18+
19+
# Third Party
20+
from transformers import (
21+
Trainer,
22+
TrainerCallback,
23+
TrainerControl,
24+
TrainerState,
25+
TrainingArguments,
26+
)
27+
from transformers.trainer import TRAINING_ARGS_NAME
28+
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
29+
import torch
1730

1831
# Local
1932
from .utils import ensure_nested_dataclasses_initialized, parsable_dataclass
2033

34+
is_recover_safetensors_from_dcp_available = True
35+
try:
36+
# Third Party
37+
from fms_acceleration_moe.utils import recover_safetensors_from_dcp
38+
except ImportError:
39+
is_recover_safetensors_from_dcp_available = False
40+
2141

2242
@parsable_dataclass
2343
@dataclass
@@ -34,3 +54,77 @@ class FastMoeConfig:
3454
def __post_init__(self):
3555
# ensure nested dataclasses initialized
3656
ensure_nested_dataclasses_initialized(self)
57+
58+
59+
def get_callbacks(**kwargs):
60+
pretrained_model_name_or_path = kwargs.pop("pretrained_model_name_or_path")
61+
trainer = kwargs.pop("trainer")
62+
callbacks = []
63+
if is_recover_safetensors_from_dcp_available:
64+
65+
class ConvertAndSaveHFCheckpointAtEverySave(TrainerCallback):
66+
def __init__(self, pretrained_model_name_or_path: str, trainer: Trainer):
67+
self.pretrained_model_name_or_path = pretrained_model_name_or_path
68+
self.trainer = trainer
69+
70+
def on_save(
71+
self,
72+
args: TrainingArguments,
73+
state: TrainerState,
74+
control: TrainerControl,
75+
**kwargs,
76+
):
77+
"""
78+
Save all HF files and convert dcp checkpoint to safetensors at every save operation.
79+
"""
80+
81+
def checkpoint():
82+
checkpoint_dir = os.path.join(
83+
args.output_dir,
84+
f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
85+
)
86+
hf_converted_output_dir = os.path.join(
87+
checkpoint_dir, "hf_converted_checkpoint"
88+
)
89+
if os.path.exists(hf_converted_output_dir):
90+
# if the folder already exists
91+
# we return, since this is possible to happen
92+
# saving the checkpointing at the end of the training
93+
return
94+
os.mkdir(hf_converted_output_dir)
95+
try:
96+
recover_safetensors_from_dcp(
97+
checkpoint_dir,
98+
self.pretrained_model_name_or_path,
99+
hf_converted_output_dir,
100+
)
101+
# save tokenizer
102+
if self.trainer.processing_class:
103+
self.trainer.processing_class.save_pretrained(
104+
hf_converted_output_dir
105+
)
106+
# save training args
107+
torch.save(
108+
args,
109+
os.path.join(hf_converted_output_dir, TRAINING_ARGS_NAME),
110+
)
111+
# save model config files
112+
self.trainer.model.config.save_pretrained(
113+
hf_converted_output_dir
114+
)
115+
116+
except Exception as e:
117+
raise ValueError(
118+
f"Failed to convert the checkpoint {checkpoint_dir}\
119+
to a HF compatible checkpoint"
120+
) from e
121+
122+
if state.is_world_process_zero:
123+
checkpoint()
124+
125+
callbacks.append(
126+
ConvertAndSaveHFCheckpointAtEverySave(
127+
pretrained_model_name_or_path, trainer
128+
)
129+
)
130+
return callbacks

tuning/sft_trainer.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
FastMoeConfig,
4949
FusedOpsAndKernelsConfig,
5050
QuantizedLoraConfig,
51+
get_additional_accel_framework_callbacks,
5152
)
5253
from tuning.config.tracker_configs import (
5354
AimConfig,
@@ -411,6 +412,12 @@ def train(
411412
# ready for train may produce additional callbacks for the trainer
412413
for x in framework.get_callbacks_and_ready_for_train(model, accelerator):
413414
trainer.add_callback(x)
415+
for clb in get_additional_accel_framework_callbacks(
416+
active_plugins=framework.active_plugins,
417+
trainer=trainer,
418+
pretrained_model_name_or_path=model_args.model_name_or_path,
419+
):
420+
trainer.add_callback(clb)
414421

415422
resume_from_checkpoint = None
416423
# Check if resume flag is not passed (None), or if flag is true and

0 commit comments

Comments
 (0)