diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py
index 037c0420f..28e0f024e 100644
--- a/tests/test_sft_trainer.py
+++ b/tests/test_sft_trainer.py
@@ -1377,6 +1377,34 @@ def test_run_moe_ft_and_inference(dataset_path):
     )


+@pytest.mark.skipif(
+    not is_fms_accelerate_available(plugins="moe"),
+    reason="Only runs if fms-accelerate is installed along with accelerated-moe plugin",
+)
+@pytest.mark.parametrize(
+    "dataset_path",
+    [
+        TWITTER_COMPLAINTS_DATA_JSONL,
+    ],
+)
+def test_run_moe_ft_with_save_model_dir(dataset_path):
+    """Check that finetuning a MoE model writes an HF checkpoint to save_model_dir."""
+    with tempfile.TemporaryDirectory() as tempdir:
+        save_model_dir = os.path.join(tempdir, "save_model")
+        data_args = copy.deepcopy(DATA_ARGS)
+        data_args.training_data_path = dataset_path
+        model_args = copy.deepcopy(MODEL_ARGS)
+        model_args.model_name_or_path = "Isotonic/TinyMixtral-4x248M-MoE"
+        train_args = copy.deepcopy(TRAIN_ARGS)
+        train_args.output_dir = tempdir
+        train_args.save_model_dir = save_model_dir
+        fast_moe_config = FastMoeConfig(fast_moe=FastMoe(ep_degree=1))
+        sft_trainer.train(
+            model_args, data_args, train_args, fast_moe_config=fast_moe_config
+        )
+        assert os.path.exists(os.path.join(save_model_dir, "hf_converted_checkpoint"))
+
+
 ############################# Helper functions #############################
 def _test_run_causallm_ft(training_args, model_args, data_args, tempdir):
     train_args = copy.deepcopy(training_args)
diff --git a/tuning/config/acceleration_configs/fast_moe.py b/tuning/config/acceleration_configs/fast_moe.py
index 6c618e7d7..f36fbf4c3 100644
--- a/tuning/config/acceleration_configs/fast_moe.py
+++ b/tuning/config/acceleration_configs/fast_moe.py
@@ -59,13 +59,20 @@ def __post_init__(self):
     def get_callbacks(**kwargs):
         pretrained_model_name_or_path = kwargs.pop("pretrained_model_name_or_path")
         trainer = kwargs.pop("trainer")
+        save_model_dir = kwargs.pop("save_model_dir")
         callbacks = []
         if is_recover_safetensors_from_dcp_available:

             class ConvertAndSaveHFCheckpointAtEverySave(TrainerCallback):
-                def __init__(self, pretrained_model_name_or_path: str, trainer: Trainer):
+                def __init__(
+                    self,
+                    pretrained_model_name_or_path: str,
+                    trainer: Trainer,
+                    save_model_dir: str,
+                ):
                     self.pretrained_model_name_or_path = pretrained_model_name_or_path
                     self.trainer = trainer
+                    self.save_model_dir = save_model_dir

                 def on_save(
                     self,
@@ -76,18 +83,15 @@ def on_save(
                 ):
                     """
                     Save all HF files and convert dcp checkpoint to safetensors at every save operation.
+                    Also saves the final model in save_model_dir if provided.
                     """

-                    def checkpoint():
-                        checkpoint_dir = os.path.join(
-                            args.output_dir,
-                            f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
-                        )
+                    def checkpoint(checkpoint_dir, save_dir):
                         hf_converted_output_dir = os.path.join(
-                            checkpoint_dir, "hf_converted_checkpoint"
+                            save_dir, "hf_converted_checkpoint"
                         )
                         if os.path.exists(hf_converted_output_dir):
-                            # if the folder already exists
+                            # If the folder already exists
                             # we return, since this is possible to happen
                             # saving the checkpointing at the end of the training
                             return
@@ -98,17 +102,17 @@ def checkpoint():
                                 self.pretrained_model_name_or_path,
                                 hf_converted_output_dir,
                             )
-                            # save tokenizer
+                            # Save tokenizer
                             if self.trainer.processing_class:
                                 self.trainer.processing_class.save_pretrained(
                                     hf_converted_output_dir
                                 )
-                            # save training args
+                            # Save training args
                             torch.save(
                                 args,
                                 os.path.join(hf_converted_output_dir, TRAINING_ARGS_NAME),
                             )
-                            # save model config files
+                            # Save model config files
                             self.trainer.model.config.save_pretrained(
                                 hf_converted_output_dir
                             )
@@ -116,15 +120,28 @@ def checkpoint():
                         except Exception as e:
                             raise ValueError(
                                 f"Failed to convert the checkpoint {checkpoint_dir}\
-                                to a HF compatible checkpoint"
+                                to a HF compatible checkpoint in {save_dir}"
                             ) from e

                     if state.is_world_process_zero:
-                        checkpoint()
+                        # Save periodic checkpoint
+                        checkpoint_dir = os.path.join(
+                            args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
+                        )
+                        checkpoint(checkpoint_dir, checkpoint_dir)
+
+                        # If a final save directory is provided, save the model there
+                        if (
+                            getattr(self, "save_model_dir", None)
+                            and state.global_step == state.max_steps
+                        ):
+                            if not os.path.exists(self.save_model_dir):
+                                os.mkdir(self.save_model_dir)
+                            checkpoint(checkpoint_dir, self.save_model_dir)

             callbacks.append(
                 ConvertAndSaveHFCheckpointAtEverySave(
-                    pretrained_model_name_or_path, trainer
+                    pretrained_model_name_or_path, trainer, save_model_dir
                 )
             )
         return callbacks
diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py
index 196fa63d8..1d090540e 100644
--- a/tuning/sft_trainer.py
+++ b/tuning/sft_trainer.py
@@ -417,6 +417,7 @@ def train(
         active_plugins=framework.active_plugins,
         trainer=trainer,
         pretrained_model_name_or_path=model_args.model_name_or_path,
+        save_model_dir=train_args.save_model_dir,
     ):
         trainer.add_callback(clb)
