diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py
index cdce01b40..f70003494 100644
--- a/tuning/sft_trainer.py
+++ b/tuning/sft_trainer.py
@@ -170,24 +170,14 @@ def train(
     if fast_moe_config is not None and fast_moe_config.fast_moe is None:
         fast_moe_config = None
     if fast_moe_config is not None:
-        # Checking for unsupported modules with Scatter MoE for LoRA
-        # Only raise an error for `all-linear`
-        restricted_modules = ["all-linear"]
+        # If LoRA with ScatterMoE targets modules other than `all-linear`, warn
+        accepted_layers = ["all-linear"]
         if (
             peft_config is not None
             and hasattr(peft_config, "target_modules")
-            and any(
-                module in (peft_config.target_modules or [])
-                for module in restricted_modules
-            )
+            and fast_moe_config.fast_moe is not None
+            and peft_config.target_modules != accepted_layers
         ):
-            raise ValueError(
-                "`--fast_moe` with LoRA does not currently support `all-linear`, as "
-                "target modules at this time. Please explicitly specify target "
-                "modules when using `--fast_moe` with LoRA."
-            )
-        # If other common non-linear modules, raise warning
-        if peft_config is not None and hasattr(peft_config, "target_modules"):
             logger.warning(
                 "You are running lora with the ScatterMoE plugin, please note that "
                 "passing target modules that are part of the moe module can cause unexpected "
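
For reference, a minimal sketch of how the changed guard behaves, assuming stub objects in place of the real `peft_config` and `fast_moe_config`; the stub class and function names below are illustrative and not part of the patch. With the new check, `all-linear` passes silently (it previously raised a `ValueError`), while explicit target modules only trigger the warning.

```python
import logging
from dataclasses import dataclass, field
from typing import List, Optional

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)


@dataclass
class PeftConfigStub:
    # Stand-in for the LoRA peft_config; only target_modules matters here.
    target_modules: List[str] = field(default_factory=lambda: ["q_proj", "v_proj"])


@dataclass
class FastMoeConfigStub:
    # Stand-in for the fast_moe acceleration config; any non-None value enables it.
    fast_moe: Optional[int] = 2


def warn_on_lora_scattermoe(peft_config, fast_moe_config):
    """Mirror of the new condition: warn only for non `all-linear` targets."""
    accepted_layers = ["all-linear"]
    if (
        peft_config is not None
        and hasattr(peft_config, "target_modules")
        and fast_moe_config.fast_moe is not None
        and peft_config.target_modules != accepted_layers
    ):
        logger.warning(
            "You are running lora with the ScatterMoE plugin; target modules "
            "inside the MoE block can cause unexpected behavior."
        )


warn_on_lora_scattermoe(PeftConfigStub(), FastMoeConfigStub())                # warns
warn_on_lora_scattermoe(PeftConfigStub(["all-linear"]), FastMoeConfigStub())  # silent (new behavior)
```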