diff --git a/README.md b/README.md index d81f6823f..7fc675ee3 100644 --- a/README.md +++ b/README.md @@ -804,7 +804,7 @@ Notes: * Notes on Fast MoE - `--fast_moe` is an integer value that configures the amount of expert parallel sharding (ep_degree). - `world_size` must be divisible by the `ep_degree` - - Running fast moe modifies the state dict of the model, and must be post-processed using [checkpoint utils](https://github.com/foundation-model-stack/fms-acceleration/blob/main/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py) to run inference (HF, vLLM, etc.). + - Running fast moe modifies the state dict of the model, so checkpoints must be post-processed to run inference (HF, vLLM, etc.). This happens automatically, and the converted checkpoint can be found in the `hf_converted_checkpoint` folder within every saved checkpoint directory. Alternatively, the same conversion can be performed manually using the [checkpoint utils](https://github.com/foundation-model-stack/fms-acceleration/blob/main/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py) script. 
- The typical usecase for this script is to run: ``` python -m fms_acceleration_moe.utils.checkpoint_utils \ diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index a002a0955..156b1b43e 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -73,6 +73,7 @@ # Local from tuning import sft_trainer from tuning.config import configs, peft_config +from tuning.config.acceleration_configs.fast_moe import FastMoe, FastMoeConfig from tuning.config.tracker_configs import FileLoggingTrackerConfig from tuning.data.data_config import ( DataConfig, @@ -85,6 +86,7 @@ DataHandlerType, add_tokenizer_eos_token, ) +from tuning.utils.import_utils import is_fms_accelerate_available MODEL_ARGS = configs.ModelArguments( model_name_or_path=MODEL_NAME, use_flash_attn=False, torch_dtype="float32" @@ -1336,6 +1338,36 @@ def test_run_e2e_with_hf_dataset_id(data_args): _test_run_inference(checkpoint_path=_get_checkpoint_path(tempdir)) +@pytest.mark.skipif( + not is_fms_accelerate_available(plugins="moe"), + reason="Only runs if fms-accelerate is installed along with accelerated-moe plugin", +) +@pytest.mark.parametrize( + "dataset_path", + [ + TWITTER_COMPLAINTS_DATA_JSONL, + ], +) +def test_run_moe_ft_and_inference(dataset_path): + """Check if we can finetune a moe model and check if hf checkpoint is created""" + with tempfile.TemporaryDirectory() as tempdir: + data_args = copy.deepcopy(DATA_ARGS) + data_args.training_data_path = dataset_path + model_args = copy.deepcopy(MODEL_ARGS) + model_args.model_name_or_path = "Isotonic/TinyMixtral-4x248M-MoE" + train_args = copy.deepcopy(TRAIN_ARGS) + train_args.output_dir = tempdir + fast_moe_config = FastMoeConfig(fast_moe=FastMoe(ep_degree=1)) + sft_trainer.train( + model_args, data_args, train_args, fast_moe_config=fast_moe_config + ) + _test_run_inference( + checkpoint_path=os.path.join( + _get_checkpoint_path(tempdir), "hf_converted_checkpoint" + ) + ) + + ############################# Helper functions 
############################# def _test_run_causallm_ft(training_args, model_args, data_args, tempdir): train_args = copy.deepcopy(training_args) diff --git a/tuning/config/acceleration_configs/__init__.py b/tuning/config/acceleration_configs/__init__.py index 98be34240..35d425c46 100644 --- a/tuning/config/acceleration_configs/__init__.py +++ b/tuning/config/acceleration_configs/__init__.py @@ -15,6 +15,7 @@ # Local from .acceleration_framework_config import AccelerationFrameworkConfig from .attention_and_distributed_packing import AttentionAndDistributedPackingConfig +from .callbacks import get_additional_accel_framework_callbacks from .fast_moe import FastMoeConfig from .fused_ops_and_kernels import FusedOpsAndKernelsConfig from .quantized_lora_config import QuantizedLoraConfig diff --git a/tuning/config/acceleration_configs/callbacks.py b/tuning/config/acceleration_configs/callbacks.py new file mode 100644 index 000000000..4571bc7fd --- /dev/null +++ b/tuning/config/acceleration_configs/callbacks.py @@ -0,0 +1,10 @@ +# Local +from .fast_moe import get_callbacks + + +def get_additional_accel_framework_callbacks(active_plugins, **kwargs): + callbacks = [] + for active_plugin in active_plugins: + if "ScatterMoEAccelerationPlugin" == active_plugin[0]: + callbacks.extend(get_callbacks(**kwargs)) + return callbacks diff --git a/tuning/config/acceleration_configs/fast_moe.py b/tuning/config/acceleration_configs/fast_moe.py index 14a44f929..6c618e7d7 100644 --- a/tuning/config/acceleration_configs/fast_moe.py +++ b/tuning/config/acceleration_configs/fast_moe.py @@ -14,10 +14,30 @@ # Standard from dataclasses import dataclass +import os + +# Third Party +from transformers import ( + Trainer, + TrainerCallback, + TrainerControl, + TrainerState, + TrainingArguments, +) +from transformers.trainer import TRAINING_ARGS_NAME +from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR +import torch # Local from .utils import ensure_nested_dataclasses_initialized, 
parsable_dataclass +is_recover_safetensors_from_dcp_available = True +try: + # Third Party + from fms_acceleration_moe.utils import recover_safetensors_from_dcp +except ImportError: + is_recover_safetensors_from_dcp_available = False + @parsable_dataclass @dataclass @@ -34,3 +54,77 @@ class FastMoeConfig: def __post_init__(self): # ensure nested dataclasses initialized ensure_nested_dataclasses_initialized(self) + + +def get_callbacks(**kwargs): + pretrained_model_name_or_path = kwargs.pop("pretrained_model_name_or_path") + trainer = kwargs.pop("trainer") + callbacks = [] + if is_recover_safetensors_from_dcp_available: + + class ConvertAndSaveHFCheckpointAtEverySave(TrainerCallback): + def __init__(self, pretrained_model_name_or_path: str, trainer: Trainer): + self.pretrained_model_name_or_path = pretrained_model_name_or_path + self.trainer = trainer + + def on_save( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + """ + Save all HF files and convert dcp checkpoint to safetensors at every save operation. 
+ """ + + def checkpoint(): + checkpoint_dir = os.path.join( + args.output_dir, + f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}", + ) + hf_converted_output_dir = os.path.join( + checkpoint_dir, "hf_converted_checkpoint" + ) + if os.path.exists(hf_converted_output_dir): + # if the folder already exists + # we return, since this is possible to happen + # saving the checkpointing at the end of the training + return + os.mkdir(hf_converted_output_dir) + try: + recover_safetensors_from_dcp( + checkpoint_dir, + self.pretrained_model_name_or_path, + hf_converted_output_dir, + ) + # save tokenizer + if self.trainer.processing_class: + self.trainer.processing_class.save_pretrained( + hf_converted_output_dir + ) + # save training args + torch.save( + args, + os.path.join(hf_converted_output_dir, TRAINING_ARGS_NAME), + ) + # save model config files + self.trainer.model.config.save_pretrained( + hf_converted_output_dir + ) + + except Exception as e: + raise ValueError( + f"Failed to convert the checkpoint {checkpoint_dir}\ + to a HF compatible checkpoint" + ) from e + + if state.is_world_process_zero: + checkpoint() + + callbacks.append( + ConvertAndSaveHFCheckpointAtEverySave( + pretrained_model_name_or_path, trainer + ) + ) + return callbacks diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py index 03248ab99..3cd490d7d 100644 --- a/tuning/sft_trainer.py +++ b/tuning/sft_trainer.py @@ -48,6 +48,7 @@ FastMoeConfig, FusedOpsAndKernelsConfig, QuantizedLoraConfig, + get_additional_accel_framework_callbacks, ) from tuning.config.tracker_configs import ( AimConfig, @@ -411,6 +412,12 @@ def train( # ready for train may produce additional callbacks for the trainer for x in framework.get_callbacks_and_ready_for_train(model, accelerator): trainer.add_callback(x) + for clb in get_additional_accel_framework_callbacks( + active_plugins=framework.active_plugins, + trainer=trainer, + pretrained_model_name_or_path=model_args.model_name_or_path, + ): + trainer.add_callback(clb) 
resume_from_checkpoint = None # Check if resume flag is not passed (None), or if flag is true and