Summary
DPOTrainer's ref adapter creation (model.add_adapter("ref", ...)) crashes when the LoRA config uses target_parameters — a PEFT feature required for training MoE models on Transformers 5.x, where expert modules are fused nn.Parameter tensors instead of individual nn.Linear layers.
PEFT currently restricts multiple adapters with target_parameters to one per model (peft#2710), so TRL's attempt to create a second "ref" adapter always fails.
Environment
- transformers==5.3.0
- trl==0.29.0
- peft==0.18.1
- Model: Qwen3-30B-A3B
Root Cause Chain
1. Transformers 5.x changed MoE architecture
In Transformers 4.x, each MoE expert was a separate nn.Module with individual nn.Linear layers:
model.layers.0.mlp.experts.0.gate_proj → nn.Linear
model.layers.0.mlp.experts.0.up_proj → nn.Linear
model.layers.0.mlp.experts.0.down_proj → nn.Linear
In Transformers 5.x, experts are fused into a single module with stacked nn.Parameter tensors:
model.layers.0.mlp.experts → Qwen3MoeExperts (single module)
model.layers.0.mlp.experts.gate_up_proj → nn.Parameter, shape [128, 1536, 2048]
model.layers.0.mlp.experts.down_proj → nn.Parameter, shape [128, 2048, 768]
This means PEFT's target_modules (which matches nn.Module names) can no longer target expert layers. Users must use target_parameters instead (introduced in peft#2498), which targets nn.Parameter objects directly.
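For illustration, a quick inspection sketch of the fused layout (attribute paths assume the structure above plus a standard *ForCausalLM wrapper, and may differ per checkpoint):

# Sketch: inspect the fused expert layout under Transformers 5.x.
experts = model.model.layers[0].mlp.experts
print(type(experts).__name__)        # Qwen3MoeExperts
print(experts.gate_up_proj.shape)    # torch.Size([128, 1536, 2048])

# target_modules matching walks named_modules(); a bare nn.Parameter such as
# gate_up_proj never appears there, so only target_parameters (which walks
# named_parameters()) can reach it.
print([n for n, _ in experts.named_modules()])     # no per-expert Linear submodules
print([n for n, _ in experts.named_parameters()])  # ['gate_up_proj', 'down_proj']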
2. PEFT restricts multi-adapter with target_parameters
PEFT intentionally blocks creating multiple adapters that use target_parameters (peft#2710) due to unresolved issues with nested .base_layer, state_dict corruption, and load-order dependencies.
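The restriction reproduces with PEFT alone, independent of TRL. A minimal sketch (base_model stands in for any model exposing the fused expert parameters):

from peft import LoraConfig, get_peft_model

cfg = LoraConfig(
    r=8,
    target_parameters=["mlp.experts.gate_up_proj"],
)
peft_model = get_peft_model(base_model, cfg)  # first adapter with target_parameters: OK
peft_model.add_adapter("second", cfg)         # second adapter: raises the ValueError shown below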
3. TRL DPOTrainer creates a "ref" adapter
DPOTrainer creates a frozen copy of the "default" adapter for computing reference logprobs:
# trl/trainer/dpo_trainer.py
if is_peft_available() and is_peft_model(model) and ref_model is None:
    model.add_adapter("ref", model.peft_config["default"])  # <-- crashes here

Since model.peft_config["default"] contains target_parameters, PEFT blocks this second adapter.
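Until this is fixed, a hypothetical pre-flight check (not part of TRL) makes the failure explicit before the trainer is constructed:

# Hypothetical guard: detect the unsupported combination up front.
# ref_model is the value you intend to pass to DPOTrainer.
default_cfg = model.peft_config["default"]
if ref_model is None and getattr(default_cfg, "target_parameters", None):
    raise RuntimeError(
        "DPOTrainer would call model.add_adapter('ref', ...), which PEFT "
        "rejects for configs with target_parameters. Pass an explicit "
        "ref_model or use the workaround below."
    )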
Reproduction
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model
from trl import DPOTrainer, DPOConfig
from datasets import Dataset
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-30B-A3B",
    torch_dtype="bfloat16",
    attn_implementation="flash_attention_2",
)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-30B-A3B")
# LoRA with target_parameters for fused MoE experts
lora_config = LoraConfig(
    r=32,
    lora_alpha=64,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    target_parameters=["mlp.experts.gate_up_proj", "mlp.experts.down_proj"],
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
# trainable params: 1,285,029,888 (4.04%) -- confirms expert LoRA is applied
# Minimal DPO dataset
data = [{"prompt": "Hello", "chosen": "Hi!", "rejected": "Bye!"}]
dataset = Dataset.from_list(data)
dpo_config = DPOConfig(output_dir="/tmp/dpo", max_length=512)
# This crashes:
trainer = DPOTrainer(
    model=model,
    ref_model=None,
    args=dpo_config,
    train_dataset=dataset,
    processing_class=tokenizer,
)

Error
ValueError: Adding a LoRA config with `target_parameters=['mlp.experts.gate_up_proj', 'mlp.experts.down_proj']`
but there are already other LoRA adapters on this model that use `target_parameters`.
At the moment, only one LoRA adapter per model with `target_parameters` is allowed.
Full traceback:
File "trl/trainer/dpo_trainer.py", line 535, in __init__
model.add_adapter("ref", model.peft_config["default"])
File "peft/peft_model.py", line 1056, in add_adapter
self.base_model.inject_adapter(...)
File "peft/tuners/tuners_utils.py", line 804, in inject_adapter
self._create_and_replace(...)
File "peft/tuners/lora/model.py", line 180, in _create_and_replace
raise ValueError(...)
Impact
This blocks LoRA fine-tuning (both SFT and DPO) of all MoE models on Transformers 5.x where experts are fused. This includes Qwen3-30B-A3B, Qwen3.5-35B-A3B, and likely future MoE models that use the fused expert pattern.
Users are forced to either:
- Downgrade to Transformers 4.x, which uses unfused per-expert nn.Linear layers and is significantly slower (128 sequential small matmuls instead of one batched operation), or
- Use a workaround that monkey-patches model.add_adapter to skip the ref adapter creation (see "Current Workaround" below).
Suggested Fix
When the default adapter uses target_parameters, skip ref adapter creation and fall back to the disable_adapter() approach for computing reference logprobs. This is functionally identical — at initialization all LoRA weights are zero, so both approaches produce base model logprobs.
# In DPOTrainer.__init__:
if is_peft_available() and is_peft_model(model) and ref_model is None:
    default_config = model.peft_config["default"]
    if getattr(default_config, "target_parameters", None):
        # Fall back to disable_adapter() for reference logprobs.
        # PEFT doesn't support multiple adapters with target_parameters.
        pass
    else:
        model.add_adapter("ref", default_config)
        for name, param in model.named_parameters():
            if ".default." in name:
                ref_name = name.replace(".default.", ".ref.")
                ref_param = model.get_parameter(ref_name)
                ref_param.data.copy_(param.data)

The loss computation already handles this gracefully:
# Already in DPOTrainer:
with use_adapter(model, adapter_name="ref" if "ref" in model.peft_config else None):
    # adapter_name=None disables the adapter → base model logprobs
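As a sanity check on the claim that the two approaches are functionally identical at initialization, one can verify that disabling the adapter reproduces the adapted model's outputs before any training step (sketch; reuses model and tokenizer from the reproduction above):

import torch

# At init, every lora_B weight is zero, so the adapter is a no-op.
inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
with torch.no_grad():
    with model.disable_adapter():  # the reference path used by the fallback
        base_logits = model(**inputs).logits
    adapted_logits = model(**inputs).logits
torch.testing.assert_close(base_logits, adapted_logits)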
Related Issues
- PEFT: peft#2710 — multi-adapter target_parameters restriction
- PEFT: peft#2498 — target_parameters feature introduction
- Transformers: fused MoE expert modules in Qwen3MoeExperts (Transformers 5.x)
Current Workaround
We monkey-patch model.add_adapter and model.get_parameter during DPOTrainer.__init__ to silently skip the ref adapter creation and absorb the weight copying. TRL then falls back to use_adapter(model, adapter_name=None) which disables the adapter for reference logprobs.
from contextlib import contextmanager

@contextmanager
def _patch_ref_adapter_for_target_parameters(model):
    has_target_params = (
        hasattr(model, "peft_config")
        and "default" in getattr(model, "peft_config", {})
        and getattr(model.peft_config["default"], "target_parameters", None)
    )
    if not has_target_params:
        yield
        return

    _orig_add_adapter = model.add_adapter
    _orig_get_parameter = model.get_parameter

    def _skip_ref_add_adapter(name, *args, **kwargs):
        if name == "ref":
            return
        return _orig_add_adapter(name, *args, **kwargs)

    class _NoopParam:
        # Stands in for the missing ".ref." parameters so that
        # ref_param.data.copy_(...) inside DPOTrainer becomes a no-op.
        class data:
            @staticmethod
            def copy_(*args):
                pass

    def _safe_get_parameter(name):
        if ".ref." in name:
            return _NoopParam()
        return _orig_get_parameter(name)

    model.add_adapter = _skip_ref_add_adapter
    model.get_parameter = _safe_get_parameter
    try:
        yield
    finally:
        model.add_adapter = _orig_add_adapter
        model.get_parameter = _orig_get_parameter
# Usage:
with _patch_ref_adapter_for_target_parameters(model):
    trainer = DPOTrainer(model=model, ...)

System Info
- Platform: Linux-6.8.0-83-generic-x86_64-with-glibc2.35
- Python version: 3.12.13
- TRL version: 0.29.0
- PyTorch version: 2.10.0
- accelerator(s): NVIDIA H200, NVIDIA H200, NVIDIA H200, NVIDIA H200
- Transformers version: 5.3.0
- Accelerate version: 1.13.0
- Accelerate config:
- compute_environment: LOCAL_MACHINE
- distributed_type: DEEPSPEED
- mixed_precision: bf16
- use_cpu: False
- debug: False
- num_processes: 4
- machine_rank: 0
- num_machines: 1
- rdzv_backend: static
- same_network: True
- main_training_function: main
- enable_cpu_affinity: False
- deepspeed_config: {'deepspeed_moe_layer_cls_names': 'Qwen3MoeSparseMoeBlock', 'gradient_accumulation_steps': 8, 'offload_optimizer_device': 'none', 'offload_param_device': 'none', 'zero3_init_flag': False, 'zero3_save_16bit_model': True, 'zero_stage': 3}
- downcast_bf16: no
- tpu_use_cluster: False
- tpu_use_sudo: False
- tpu_env: []
- Datasets version: 4.6.1
- HF Hub version: 1.5.0
- bitsandbytes version: not installed
- DeepSpeed version: 0.18.6
- Liger-Kernel version: not installed
- LLM-Blender version: not installed
- OpenAI version: 2.24.0
- PEFT version: 0.18.1
- vLLM version: not installed
Checklist
- I have checked that my issue isn't already filed (see open issues)
- I have included my system information
- Any code provided is minimal, complete, and reproducible (more on MREs)
- Any code provided is properly formatted in code blocks (no screenshots; more on code blocks)
- Any traceback provided is complete