|
43 | 43 | MultiPack, |
44 | 44 | PaddingFree, |
45 | 45 | ) |
| 46 | +from tuning.config.acceleration_configs.fast_moe import FastMoe, FastMoeConfig |
46 | 47 | from tuning.config.acceleration_configs.fused_ops_and_kernels import ( |
47 | 48 | FastKernelsConfig, |
48 | 49 | FusedLoraConfig, |
|
56 | 57 | # for some reason the CI will raise an import error if we try to import |
57 | 58 | # these from tests.artifacts.testdata |
58 | 59 | TWITTER_COMPLAINTS_JSON_FORMAT = os.path.join( |
59 | | - os.path.dirname(__file__), "../artifacts/testdata/twitter_complaints_json.json" |
| 60 | + os.path.dirname(__file__), |
| 61 | + "../artifacts/testdata/json/twitter_complaints_small.json", |
60 | 62 | ) |
61 | 63 | TWITTER_COMPLAINTS_TOKENIZED = os.path.join( |
62 | 64 | os.path.dirname(__file__), |
|
87 | 89 | # Third Party |
88 | 90 | from fms_acceleration_aadp import PaddingFreeAccelerationPlugin |
89 | 91 |
|
| 92 | + if is_fms_accelerate_available(plugins="moe"): |
| 93 | + # Third Party |
| 94 | + from fms_acceleration_moe import ScatterMoEAccelerationPlugin |
| 95 | + |
90 | 96 |
|
91 | 97 | # There are more extensive unit tests in the |
92 | 98 | # https://github.com/foundation-model-stack/fms-acceleration |
@@ -360,7 +366,7 @@ def test_framework_raises_due_to_invalid_arguments( |
360 | 366 | acceleration_configs_map, |
361 | 367 | ids=["bitsandbytes", "auto_gptq"], |
362 | 368 | ) |
363 | | -def test_framework_intialized_properly_peft( |
| 369 | +def test_framework_initialized_properly_peft( |
364 | 370 | quantized_lora_config, model_name_or_path, mock_and_spy |
365 | 371 | ): |
366 | 372 | """Ensure that specifying a properly configured acceleration dataclass |
@@ -412,7 +418,7 @@ def test_framework_intialized_properly_peft( |
412 | 418 | "and foak plugins" |
413 | 419 | ), |
414 | 420 | ) |
415 | | -def test_framework_intialized_properly_foak(): |
| 421 | +def test_framework_initialized_properly_foak(): |
416 | 422 | """Ensure that specifying a properly configured acceleration dataclass |
417 | 423 | properly activates the framework plugin and runs the train successfully. |
418 | 424 | """ |
@@ -477,6 +483,60 @@ def test_framework_intialized_properly_foak(): |
477 | 483 | assert spy2["get_ready_for_train_calls"] == 1 |
478 | 484 |
|
479 | 485 |
|
| 486 | +@pytest.mark.skipif( |
| 487 | + not is_fms_accelerate_available(plugins="moe"), |
| 488 | + reason="Only runs if fms-accelerate is installed along with accelerated-moe plugin", |
| 489 | +) |
| 490 | +def test_framework_initialized_properly_moe(): |
| 491 | + """Ensure that specifying a properly configured acceleration dataclass |
| 492 | + properly activates the framework plugin and runs the train successfully. |
| 493 | + """ |
| 494 | + |
| 495 | + with tempfile.TemporaryDirectory() as tempdir: |
| 496 | + |
| 497 | + model_args = copy.deepcopy(MODEL_ARGS) |
| 498 | + model_args.model_name_or_path = "Isotonic/TinyMixtral-4x248M-MoE" |
| 499 | + model_args.torch_dtype = torch.bfloat16 |
| 500 | + train_args = copy.deepcopy(TRAIN_ARGS) |
| 501 | + train_args.output_dir = tempdir |
| 502 | + train_args.save_strategy = "no" |
| 503 | + train_args.bf16 = True |
| 504 | + data_args = copy.deepcopy(DATA_ARGS) |
| 505 | + data_args.training_data_path = TWITTER_COMPLAINTS_JSON_FORMAT |
| 506 | + data_args.response_template = "\n\n### Label:" |
| 507 | + data_args.dataset_text_field = "output" |
| 508 | + |
| 509 | + # initialize a config |
| 510 | + moe_config = FastMoeConfig(fast_moe=FastMoe(ep_degree=1)) |
| 511 | + |
| 512 | + # create mocked plugin class for spying |
| 513 | + MockedPlugin1, spy = create_mock_plugin_class_and_spy( |
| 514 | + "FastMoeMock", ScatterMoEAccelerationPlugin |
| 515 | + ) |
| 516 | + |
| 517 | + # 1. mock a plugin class |
| 518 | + # 2. register the mocked plugins |
| 519 | + # 3. call sft_trainer.train |
| 520 | + with build_framework_and_maybe_instantiate( |
| 521 | + [ |
| 522 | + (["training.moe.scattermoe"], MockedPlugin1), |
| 523 | + ], |
| 524 | + instantiate=False, |
| 525 | + ): |
| 526 | + with instantiate_model_patcher(): |
| 527 | + sft_trainer.train( |
| 528 | + model_args, |
| 529 | + data_args, |
| 530 | + train_args, |
| 531 | + fast_moe_config=moe_config, |
| 532 | + ) |
| 533 | + |
| 534 | + # spy inside the train to ensure that the moe plugin is called |
| 535 | + assert spy["model_loader_calls"] == 1 |
| 536 | + assert spy["augmentation_calls"] == 0 |
| 537 | + assert spy["get_ready_for_train_calls"] == 1 |
| 538 | + |
| 539 | + |
480 | 540 | @pytest.mark.skipif( |
481 | 541 | not is_fms_accelerate_available(plugins="aadp"), |
482 | 542 | reason="Only runs if fms-accelerate is installed along with \ |
@@ -661,6 +721,100 @@ def test_error_raised_with_fused_lora_enabled_without_quantized_argument(): |
661 | 721 | ) |
662 | 722 |
|
663 | 723 |
|
| 724 | +@pytest.mark.skipif( |
| 725 | + not is_fms_accelerate_available(plugins="moe"), |
| 726 | + reason="Only runs if fms-accelerate is installed along with accelerated-moe plugin", |
| 727 | +) |
| 728 | +def test_error_raised_with_undividable_fastmoe_argument(): |
| 729 | + """ |
| 730 | + Ensure error is thrown when `--fast_moe` is passed and world_size |
| 731 | + is not divisible by ep_degree |
| 732 | + """ |
| 733 | + with pytest.raises( |
| 734 | + AssertionError, match="world size \\(1\\) not divisible by ep_size \\(3\\)" |
| 735 | + ): |
| 736 | + with tempfile.TemporaryDirectory() as tempdir: |
| 737 | + |
| 738 | + model_args = copy.deepcopy(MODEL_ARGS) |
| 739 | + model_args.model_name_or_path = "Isotonic/TinyMixtral-4x248M-MoE" |
| 740 | + model_args.torch_dtype = torch.bfloat16 |
| 741 | + train_args = copy.deepcopy(TRAIN_ARGS) |
| 742 | + train_args.output_dir = tempdir |
| 743 | + train_args.save_strategy = "no" |
| 744 | + train_args.bf16 = True |
| 745 | + data_args = copy.deepcopy(DATA_ARGS) |
| 746 | + data_args.training_data_path = TWITTER_COMPLAINTS_JSON_FORMAT |
| 747 | + data_args.response_template = "\n\n### Label:" |
| 748 | + data_args.dataset_text_field = "output" |
| 749 | + |
| 750 | + # initialize a config |
| 751 | + moe_config = FastMoeConfig(fast_moe=FastMoe(ep_degree=3)) |
| 752 | + |
| 753 | + # 1. register the ScatterMoE plugin |
| 754 | + # 2. call sft_trainer.train and expect the |
| 755 | + #    ep_size assertion above to be raised |
| 756 | + with build_framework_and_maybe_instantiate( |
| 757 | + [ |
| 758 | + (["training.moe.scattermoe"], ScatterMoEAccelerationPlugin), |
| 759 | + ], |
| 760 | + instantiate=False, |
| 761 | + ): |
| 762 | + with instantiate_model_patcher(): |
| 763 | + sft_trainer.train( |
| 764 | + model_args, |
| 765 | + data_args, |
| 766 | + train_args, |
| 767 | + fast_moe_config=moe_config, |
| 768 | + ) |
| 769 | + |
| 770 | + |
| 771 | +@pytest.mark.skipif( |
| 772 | + not is_fms_accelerate_available(plugins="moe"), |
| 773 | + reason="Only runs if fms-accelerate is installed along with accelerated-moe plugin", |
| 774 | +) |
| 775 | +def test_error_raised_fast_moe_with_non_moe_model(): |
| 776 | + """ |
| 777 | + Ensure error is thrown when `--fast_moe` is passed and the model is not a MoE model |
| 778 | + """ |
| 779 | + with pytest.raises( |
| 780 | + AttributeError, |
| 781 | + match="'LlamaConfig' object has no attribute 'num_local_experts'", |
| 782 | + ): |
| 783 | + with tempfile.TemporaryDirectory() as tempdir: |
| 784 | + |
| 785 | + model_args = copy.deepcopy(MODEL_ARGS) |
| 786 | + model_args.model_name_or_path = "TinyLlama/TinyLlama-1.1B-Chat-v0.3" |
| 787 | + model_args.torch_dtype = torch.bfloat16 |
| 788 | + train_args = copy.deepcopy(TRAIN_ARGS) |
| 789 | + train_args.output_dir = tempdir |
| 790 | + train_args.save_strategy = "no" |
| 791 | + train_args.bf16 = True |
| 792 | + data_args = copy.deepcopy(DATA_ARGS) |
| 793 | + data_args.training_data_path = TWITTER_COMPLAINTS_JSON_FORMAT |
| 794 | + data_args.response_template = "\n\n### Label:" |
| 795 | + data_args.dataset_text_field = "output" |
| 796 | + |
| 797 | + # initialize a config |
| 798 | + moe_config = FastMoeConfig(fast_moe=FastMoe(ep_degree=1)) |
| 799 | + |
| 800 | + # 1. register the ScatterMoE plugin |
| 801 | + # 2. call sft_trainer.train and expect the |
| 802 | + #    AttributeError above to be raised |
| 803 | + with build_framework_and_maybe_instantiate( |
| 804 | + [ |
| 805 | + (["training.moe.scattermoe"], ScatterMoEAccelerationPlugin), |
| 806 | + ], |
| 807 | + instantiate=False, |
| 808 | + ): |
| 809 | + with instantiate_model_patcher(): |
| 810 | + sft_trainer.train( |
| 811 | + model_args, |
| 812 | + data_args, |
| 813 | + train_args, |
| 814 | + fast_moe_config=moe_config, |
| 815 | + ) |
| 816 | + |
| 817 | + |
664 | 818 | @pytest.mark.skipif( |
665 | 819 | not is_fms_accelerate_available(plugins="foak"), |
666 | 820 | reason="Only runs if fms-accelerate is installed along with \ |
|