2 changes: 1 addition & 1 deletion README.md
@@ -804,7 +804,7 @@ Notes:
* Notes on Fast MoE
- `--fast_moe` is an integer value that configures the amount of expert parallel sharding (ep_degree).
- `world_size` must be divisible by the `ep_degree`
- Running fast moe modifies the state dict of the model, and must be post-processed using [checkpoint utils](https://github.com/foundation-model-stack/fms-acceleration/blob/main/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py) to run inference (HF, vLLM, etc.).
- Running fast moe modifies the state dict of the model, so checkpoints must be post-processed to run inference (HF, vLLM, etc.). This happens automatically, and the converted checkpoint can be found in the `hf_converted_checkpoint` folder within every saved checkpoint directory. Alternatively, the conversion can be performed manually with the [checkpoint utils](https://github.com/foundation-model-stack/fms-acceleration/blob/main/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py) script.
- The typical use case for this script is to run:
```
python -m fms_acceleration_moe.utils.checkpoint_utils \
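As a point of reference, here is a minimal, hypothetical sketch of loading one of these auto-converted checkpoints for inference with `transformers`; the `checkpoint-100` path and the prompt are illustrative and depend on your `output_dir` and save interval:

```
# Minimal sketch: load the checkpoint produced automatically at
# <output_dir>/checkpoint-<step>/hf_converted_checkpoint (path below is hypothetical).
from transformers import AutoModelForCausalLM, AutoTokenizer

ckpt = "output/checkpoint-100/hf_converted_checkpoint"
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = AutoModelForCausalLM.from_pretrained(ckpt)

inputs = tokenizer("### Text: the app keeps crashing\n\n### Label:", return_tensors="pt")
output = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```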
32 changes: 32 additions & 0 deletions tests/test_sft_trainer.py
@@ -73,6 +73,7 @@
# Local
from tuning import sft_trainer
from tuning.config import configs, peft_config
from tuning.config.acceleration_configs.fast_moe import FastMoe, FastMoeConfig
from tuning.config.tracker_configs import FileLoggingTrackerConfig
from tuning.data.data_config import (
    DataConfig,
@@ -85,6 +86,7 @@
    DataHandlerType,
    add_tokenizer_eos_token,
)
from tuning.utils.import_utils import is_fms_accelerate_available

MODEL_ARGS = configs.ModelArguments(
    model_name_or_path=MODEL_NAME, use_flash_attn=False, torch_dtype="float32"
@@ -1336,6 +1338,36 @@ def test_run_e2e_with_hf_dataset_id(data_args):
        _test_run_inference(checkpoint_path=_get_checkpoint_path(tempdir))


@pytest.mark.skipif(
    not is_fms_accelerate_available(plugins="moe"),
    reason="Only runs if fms-accelerate is installed along with accelerated-moe plugin",
)
@pytest.mark.parametrize(
    "dataset_path",
    [
        TWITTER_COMPLAINTS_DATA_JSONL,
    ],
)
def test_run_moe_ft_and_inference(dataset_path):
    """Check that we can fine-tune a MoE model and that an HF-converted checkpoint is created."""
    with tempfile.TemporaryDirectory() as tempdir:
        data_args = copy.deepcopy(DATA_ARGS)
        data_args.training_data_path = dataset_path
        model_args = copy.deepcopy(MODEL_ARGS)
        model_args.model_name_or_path = "Isotonic/TinyMixtral-4x248M-MoE"
        train_args = copy.deepcopy(TRAIN_ARGS)
        train_args.output_dir = tempdir
        fast_moe_config = FastMoeConfig(fast_moe=FastMoe(ep_degree=1))
        sft_trainer.train(
            model_args, data_args, train_args, fast_moe_config=fast_moe_config
        )
        _test_run_inference(
            checkpoint_path=os.path.join(
                _get_checkpoint_path(tempdir), "hf_converted_checkpoint"
            )
        )


############################# Helper functions #############################
def _test_run_causallm_ft(training_args, model_args, data_args, tempdir):
    train_args = copy.deepcopy(training_args)
1 change: 1 addition & 0 deletions tuning/config/acceleration_configs/__init__.py
@@ -15,6 +15,7 @@
# Local
from .acceleration_framework_config import AccelerationFrameworkConfig
from .attention_and_distributed_packing import AttentionAndDistributedPackingConfig
from .callbacks import get_additional_accel_framework_callbacks
from .fast_moe import FastMoeConfig
from .fused_ops_and_kernels import FusedOpsAndKernelsConfig
from .quantized_lora_config import QuantizedLoraConfig
10 changes: 10 additions & 0 deletions tuning/config/acceleration_configs/callbacks.py
@@ -0,0 +1,10 @@
# Local
from .fast_moe import get_callbacks


def get_additional_accel_framework_callbacks(active_plugins, **kwargs):
    """Collect extra trainer callbacks required by the active acceleration
    plugins; currently only the ScatterMoE plugin contributes callbacks."""
    callbacks = []
    for active_plugin in active_plugins:
        if "ScatterMoEAccelerationPlugin" == active_plugin[0]:
            callbacks.extend(get_callbacks(**kwargs))
    return callbacks
94 changes: 94 additions & 0 deletions tuning/config/acceleration_configs/fast_moe.py
@@ -14,10 +14,30 @@

# Standard
from dataclasses import dataclass
import os

# Third Party
from transformers import (
    Trainer,
    TrainerCallback,
    TrainerControl,
    TrainerState,
    TrainingArguments,
)
from transformers.trainer import TRAINING_ARGS_NAME
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
import torch

# Local
from .utils import ensure_nested_dataclasses_initialized, parsable_dataclass

is_recover_safetensors_from_dcp_available = True
try:
    # Third Party
    from fms_acceleration_moe.utils import recover_safetensors_from_dcp
except ImportError:
    is_recover_safetensors_from_dcp_available = False


@parsable_dataclass
@dataclass
@@ -34,3 +54,77 @@ class FastMoeConfig:
    def __post_init__(self):
        # ensure nested dataclasses initialized
        ensure_nested_dataclasses_initialized(self)


def get_callbacks(**kwargs):
    pretrained_model_name_or_path = kwargs.pop("pretrained_model_name_or_path")
    trainer = kwargs.pop("trainer")
    callbacks = []
    if is_recover_safetensors_from_dcp_available:

        class ConvertAndSaveHFCheckpointAtEverySave(TrainerCallback):
            def __init__(self, pretrained_model_name_or_path: str, trainer: Trainer):
                self.pretrained_model_name_or_path = pretrained_model_name_or_path
                self.trainer = trainer

            def on_save(
                self,
                args: TrainingArguments,
                state: TrainerState,
                control: TrainerControl,
                **kwargs,
            ):
                """
                Save all HF files and convert the DCP checkpoint to safetensors at every save operation.
                """

                def checkpoint():
                    checkpoint_dir = os.path.join(
                        args.output_dir,
                        f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
                    )
                    hf_converted_output_dir = os.path.join(
                        checkpoint_dir, "hf_converted_checkpoint"
                    )
                    if os.path.exists(hf_converted_output_dir):
                        # The folder may already exist when the final save at the
                        # end of training revisits an already-converted checkpoint,
                        # in which case there is nothing left to do.
                        return
                    os.mkdir(hf_converted_output_dir)
                    try:
                        recover_safetensors_from_dcp(
                            checkpoint_dir,
                            self.pretrained_model_name_or_path,
                            hf_converted_output_dir,
                        )
                        # save tokenizer
                        if self.trainer.processing_class:
                            self.trainer.processing_class.save_pretrained(
                                hf_converted_output_dir
                            )
                        # save training args
                        torch.save(
                            args,
                            os.path.join(hf_converted_output_dir, TRAINING_ARGS_NAME),
                        )
                        # save model config files
                        self.trainer.model.config.save_pretrained(
                            hf_converted_output_dir
                        )

                    except Exception as e:
                        raise ValueError(
                            f"Failed to convert the checkpoint {checkpoint_dir} "
                            "to a HF compatible checkpoint"
                        ) from e

                if state.is_world_process_zero:
                    checkpoint()

        callbacks.append(
            ConvertAndSaveHFCheckpointAtEverySave(
                pretrained_model_name_or_path, trainer
            )
        )
    return callbacks
7 changes: 7 additions & 0 deletions tuning/sft_trainer.py
@@ -48,6 +48,7 @@
    FastMoeConfig,
    FusedOpsAndKernelsConfig,
    QuantizedLoraConfig,
    get_additional_accel_framework_callbacks,
)
from tuning.config.tracker_configs import (
    AimConfig,
@@ -411,6 +412,12 @@ def train(
        # ready for train may produce additional callbacks for the trainer
        for x in framework.get_callbacks_and_ready_for_train(model, accelerator):
            trainer.add_callback(x)
        for clb in get_additional_accel_framework_callbacks(
            active_plugins=framework.active_plugins,
            trainer=trainer,
            pretrained_model_name_or_path=model_args.model_name_or_path,
        ):
            trainer.add_callback(clb)

    resume_from_checkpoint = None
    # Check if resume flag is not passed (None), or if flag is true and