21 | 21 | except ImportError: |
22 | 22 |     grouped_gemm = None |
23 | 23 |
| 24 | +try: |
| 25 | +    import transformer_engine  # noqa: F401 |
| 26 | +    HAS_TE = True |
| 27 | +except ImportError: |
| 28 | +    HAS_TE = False |
| 29 | + |
24 | 30 | from nemo_automodel.components.moe.config import MoEConfig |
25 | 31 | from nemo_automodel.components.moe.layers import GroupedExperts, GroupedExpertsDeepEP, GroupedExpertsTE |
26 | 32 | from nemo_automodel.components._peft.lora_experts import GroupedExpertsLoRA, GroupedExpertsDeepEPLoRA |
@@ -608,22 +614,28 @@ def test_deepep_lora_zero_tokens(moe_config, device): |
608 | 614 |
609 | 615 |
610 | 616 | @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") |
| 617 | +@pytest.mark.skipif(not HAS_TE, reason="Transformer Engine required") |
611 | 618 | def test_patch_moe_module_rejects_te_experts(moe_config, device): |
612 | 619 |     """Test that patch_moe_module raises NotImplementedError for GroupedExpertsTE.""" |
613 | | -    orig_experts = GroupedExpertsTE(moe_config).to(device) |
| 620 | +    orig_experts = GroupedExpertsTE(moe_config) |
| 621 | +    orig_experts.init_weights(buffer_device=device) |
| 622 | +    orig_experts = orig_experts.to(device) |
614 | 623 |     with pytest.raises(NotImplementedError, match="LoRA is not supported for Transformer Engine"): |
615 | 624 |         patch_moe_module(orig_experts, dim=4) |
616 | 625 |
617 | 626 |
618 | 627 | @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") |
| 628 | +@pytest.mark.skipif(not HAS_TE, reason="Transformer Engine required") |
619 | 629 | def test_apply_lora_rejects_te_experts(moe_config, device): |
620 | 630 |     """Test that apply_lora_to_linear_modules raises NotImplementedError for GroupedExpertsTE.""" |
621 | 631 |     class MockModel(nn.Module): |
622 | 632 |         def __init__(self): |
623 | 633 |             super().__init__() |
624 | 634 |             self.experts = GroupedExpertsTE(moe_config) |
625 | 635 |
626 | | -    model = MockModel().to(device) |
| 636 | +    model = MockModel() |
| 637 | +    model.experts.init_weights(buffer_device=device) |
| 638 | +    model = model.to(device) |
627 | 639 |     peft_config = PeftConfig(target_modules=["experts"], dim=4) |
628 | 640 |
|
629 | 641 |     with pytest.raises(NotImplementedError, match="LoRA is not supported for Transformer Engine"): |
@@ -841,11 +853,6 @@ def __init__(self): |
841 | 853 |         dim=16, |
842 | 854 |         moe_rank_scaling=True, |
843 | 855 |     ) |
844 | | -    import logging as _logging |
845 | | - |
846 | | -    with pytest.warns(None) as _: |
847 | | -        # We check the logger instead of pytest.warns (logger.warning, not warnings.warn) |
848 | | -        pass |
849 | 856 |
850 | 857 |     with patch("nemo_automodel.components._peft.lora.logger") as mock_logger: |
851 | 858 |         count = apply_lora_to_linear_modules(model, peft_config) |