|
73 | 73 | # Local |
74 | 74 | from tuning import sft_trainer |
75 | 75 | from tuning.config import configs, peft_config |
| 76 | +from tuning.config.acceleration_configs.fast_moe import FastMoe, FastMoeConfig |
76 | 77 | from tuning.config.tracker_configs import FileLoggingTrackerConfig |
77 | 78 | from tuning.data.data_config import ( |
78 | 79 | DataConfig, |
|
85 | 86 | DataHandlerType, |
86 | 87 | add_tokenizer_eos_token, |
87 | 88 | ) |
| 89 | +from tuning.utils.import_utils import is_fms_accelerate_available |
88 | 90 |
|
89 | 91 | MODEL_ARGS = configs.ModelArguments( |
90 | 92 | model_name_or_path=MODEL_NAME, use_flash_attn=False, torch_dtype="float32" |
@@ -1336,6 +1338,36 @@ def test_run_e2e_with_hf_dataset_id(data_args): |
1336 | 1338 | _test_run_inference(checkpoint_path=_get_checkpoint_path(tempdir)) |
1337 | 1339 |
|
1338 | 1340 |
|
@pytest.mark.skipif(
    not is_fms_accelerate_available(plugins="moe"),
    reason="Only runs if fms-accelerate is installed along with accelerated-moe plugin",
)
@pytest.mark.parametrize(
    "dataset_path",
    [
        TWITTER_COMPLAINTS_DATA_JSONL,
    ],
)
def test_run_moe_ft_and_inference(dataset_path):
    """Fine-tune a small MoE model end-to-end and run inference against the
    HF-converted checkpoint the accelerated-moe plugin is expected to emit."""
    with tempfile.TemporaryDirectory() as workdir:
        # Deep-copy the module-level argument fixtures so this test can
        # mutate them without affecting sibling tests.
        data_args = copy.deepcopy(DATA_ARGS)
        model_args = copy.deepcopy(MODEL_ARGS)
        train_args = copy.deepcopy(TRAIN_ARGS)

        data_args.training_data_path = dataset_path
        model_args.model_name_or_path = "Isotonic/TinyMixtral-4x248M-MoE"
        train_args.output_dir = workdir

        # ep_degree=1: smallest expert-parallel degree, runnable on one device.
        moe_config = FastMoeConfig(fast_moe=FastMoe(ep_degree=1))
        sft_trainer.train(
            model_args, data_args, train_args, fast_moe_config=moe_config
        )

        # Inference loads from "hf_converted_checkpoint" under the final
        # checkpoint dir — presumably written by the MoE plugin; the call
        # fails (and the test with it) if that conversion did not happen.
        hf_checkpoint = os.path.join(
            _get_checkpoint_path(workdir), "hf_converted_checkpoint"
        )
        _test_run_inference(checkpoint_path=hf_checkpoint)
1339 | 1371 | ############################# Helper functions ############################# |
1340 | 1372 | def _test_run_causallm_ft(training_args, model_args, data_args, tempdir): |
1341 | 1373 | train_args = copy.deepcopy(training_args) |
|
0 commit comments