18 changes: 4 additions & 14 deletions tuning/sft_trainer.py
@@ -170,24 +170,14 @@ def train(
     if fast_moe_config is not None and fast_moe_config.fast_moe is None:
         fast_moe_config = None
     if fast_moe_config is not None:
-        # Checking for unsupported modules with Scatter MoE for LoRA
-        # Only raise an error for `all-linear`
-        restricted_modules = ["all-linear"]
+        # If LoRA with ScatterMoE detected, raise warning
+        accepted_layers = ["all-linear"]
         if (
             peft_config is not None
             and hasattr(peft_config, "target_modules")
-            and any(
-                module in (peft_config.target_modules or [])
-                for module in restricted_modules
-            )
+            and fast_moe_config.fast_moe is not None
+            and peft_config.target_modules != accepted_layers
         ):
-            raise ValueError(
-                "`--fast_moe` with LoRA does not currently support `all-linear`, as "
-                "target modules at this time. Please explicitly specify target "
-                "modules when using `--fast_moe` with LoRA."
-            )
-        # If other common non-linear modules, raise warning
-        if peft_config is not None and hasattr(peft_config, "target_modules"):
             logger.warning(
                 "You are running lora with the ScatterMoE plugin, please note that "
                 "passing target modules that are part of the moe module can cause unexpected "
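As a rough illustration of the behavioral change, the sketch below re-implements the new condition in isolation: with ScatterMoE (`fast_moe`) active, LoRA with `all-linear` target modules is now accepted silently (the old code raised a ValueError for it), while any other `target_modules` value only logs a warning. `PeftConfigStub`, `check_scattermoe_lora_targets`, and the abridged warning text are hypothetical stand-ins for illustration, not code from the PR.

import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)


# Hypothetical stand-in for the real LoraConfig object; only the
# `target_modules` attribute matters for this check.
class PeftConfigStub:
    def __init__(self, target_modules):
        self.target_modules = target_modules


def check_scattermoe_lora_targets(peft_config, fast_moe_enabled):
    """Mirror of the new warning logic from the diff: anything other than
    exactly ["all-linear"] triggers a warning instead of the old ValueError."""
    accepted_layers = ["all-linear"]
    if (
        peft_config is not None
        and hasattr(peft_config, "target_modules")
        and fast_moe_enabled
        and peft_config.target_modules != accepted_layers
    ):
        # Abridged paraphrase of the (truncated) warning message in the diff.
        logger.warning(
            "Running LoRA with the ScatterMoE plugin: target modules that are "
            "part of the MoE module can cause unexpected behavior."
        )


# ["all-linear"] now passes silently (previously it raised a ValueError).
check_scattermoe_lora_targets(PeftConfigStub(["all-linear"]), True)
# An explicit module list only warns; training proceeds.
check_scattermoe_lora_targets(PeftConfigStub(["q_proj", "v_proj"]), True)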