Commit 5f128d4

feat: add test case for moe saving hf checkpoint
Signed-off-by: Mehant Kammakomati <[email protected]>
1 parent 72e9533 · commit 5f128d4

File tree

1 file changed: +32 −0 lines changed


tests/test_sft_trainer.py

Lines changed: 32 additions & 0 deletions
@@ -73,6 +73,7 @@
 # Local
 from tuning import sft_trainer
 from tuning.config import configs, peft_config
+from tuning.config.acceleration_configs.fast_moe import FastMoe, FastMoeConfig
 from tuning.config.tracker_configs import FileLoggingTrackerConfig
 from tuning.data.data_config import (
     DataConfig,
@@ -85,6 +86,7 @@
     DataHandlerType,
     add_tokenizer_eos_token,
 )
+from tuning.utils.import_utils import is_fms_accelerate_available

 MODEL_ARGS = configs.ModelArguments(
     model_name_or_path=MODEL_NAME, use_flash_attn=False, torch_dtype="float32"
@@ -1336,6 +1338,36 @@ def test_run_e2e_with_hf_dataset_id(data_args):
     _test_run_inference(checkpoint_path=_get_checkpoint_path(tempdir))


+@pytest.mark.skipif(
+    not is_fms_accelerate_available(plugins="moe"),
+    reason="Only runs if fms-accelerate is installed along with accelerated-moe plugin",
+)
+@pytest.mark.parametrize(
+    "dataset_path",
+    [
+        TWITTER_COMPLAINTS_DATA_JSONL,
+    ],
+)
+def test_run_moe_ft_and_inference(dataset_path):
+    """Check if we can finetune a moe model and check if hf checkpoint is created"""
+    with tempfile.TemporaryDirectory() as tempdir:
+        data_args = copy.deepcopy(DATA_ARGS)
+        data_args.training_data_path = dataset_path
+        model_args = copy.deepcopy(MODEL_ARGS)
+        model_args.model_name_or_path = "Isotonic/TinyMixtral-4x248M-MoE"
+        train_args = copy.deepcopy(TRAIN_ARGS)
+        train_args.output_dir = tempdir
+        fast_moe_config = FastMoeConfig(fast_moe=FastMoe(ep_degree=1))
+        sft_trainer.train(
+            model_args, data_args, train_args, fast_moe_config=fast_moe_config
+        )
+        _test_run_inference(
+            checkpoint_path=os.path.join(
+                _get_checkpoint_path(tempdir), "hf_converted_checkpoint"
+            )
+        )
+
+
 ############################# Helper functions #############################
 def _test_run_causallm_ft(training_args, model_args, data_args, tempdir):
     train_args = copy.deepcopy(training_args)
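
The new test runs end-to-end finetuning of Isotonic/TinyMixtral-4x248M-MoE with the fast-MoE plugin (ep_degree=1), then runs inference from the hf_converted_checkpoint directory that the MoE path writes inside the regular checkpoint folder. A rough sketch of consuming that artifact downstream (the hf_converted_checkpoint name comes from the test above; the output directory and the use of the stock Hugging Face Auto classes are illustrative assumptions, not part of this commit):

import os

from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative layout: in the test this lives under a TemporaryDirectory,
# i.e. <output_dir>/checkpoint-<step>/hf_converted_checkpoint.
hf_ckpt = os.path.join("outputs", "checkpoint-5", "hf_converted_checkpoint")

# The converted checkpoint is expected to be plain Hugging Face format,
# so the stock Auto classes should load it without the MoE plugin.
model = AutoModelForCausalLM.from_pretrained(hf_ckpt)
tokenizer = AutoTokenizer.from_pretrained(hf_ckpt)

To run just this test, something like "pytest tests/test_sft_trainer.py -k test_run_moe_ft_and_inference" should work, provided fms-acceleration is installed with the accelerated-moe plugin; otherwise the skipif marker skips it.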
