Status: Closed
Labels: bug
Description
Describe the bug
Loading a Flux LoRA, calling `delete_adapters`, and then loading another LoRA (or the same one again, as in the reproduction below) is practically guaranteed to raise a `KeyError`, because `delete_adapters` leaves the transformer's state dict inconsistent with what `load_lora_weights` expects.
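To make the suspected inconsistency concrete, here is a toy sketch that uses plain dicts in place of `transformer.state_dict()`. It rests on an assumption that this report does not confirm: that PEFT wraps each targeted Linear so its weight is exposed as `<module>.base_layer.weight`, and that `delete_adapters` removes the LoRA weights but leaves that wrapper in place. `naive_lookup` and `resolve_base_weight` are hypothetical helpers for illustration, not diffusers functions.

```python
# Hypothetical illustration only: plain dicts stand in for transformer.state_dict().
# ASSUMPTION: a PEFT-wrapped Linear exposes its original weight as
# "<module>.base_layer.weight" instead of "<module>.weight".

module = "single_transformer_blocks.0.attn.to_k"

# Keys before any LoRA is loaded: the plain Linear weight.
clean_state_dict = {f"{module}.weight": "W"}

# Keys after load_lora_weights + delete_adapters, ASSUMING the wrapper stays.
wrapped_state_dict = {f"{module}.base_layer.weight": "W"}


def naive_lookup(state_dict, name):
    """Hypothetical: always asks for '<name>.weight', like the failing lookup in the logs."""
    return state_dict[f"{name}.weight"]


def resolve_base_weight(state_dict, name):
    """Hypothetical: prefer the wrapped key, fall back to the plain one."""
    for key in (f"{name}.base_layer.weight", f"{name}.weight"):
        if key in state_dict:
            return state_dict[key]
    raise KeyError(f"no base weight found for {name!r}")


print(naive_lookup(clean_state_dict, module))  # W
try:
    naive_lookup(wrapped_state_dict, module)
except KeyError as err:
    print("KeyError:", err)  # KeyError: 'single_transformer_blocks.0.attn.to_k.weight'
print(resolve_base_weight(wrapped_state_dict, module))  # W
```

The first lookup succeeds on the clean keys, the second reproduces the same `KeyError` message shown in the logs, and the fallback lookup finds the wrapped weight.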
Possibly related to #11003, but that issue was opened against 0.32.x, which does not exhibit the specific problem described here.
Reproduction
```python
import torch
from diffusers import FluxPipeline
from safetensors.torch import load_file

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
aname = "bml"
sd = load_file("/tmp/bml.safetensors")  # https://civitai.com/models/743448?modelVersionId=831442
pipe.load_lora_weights(sd, adapter_name=aname)
pipe.delete_adapters([aname])
# Calling this instead of delete_adapters() still works,
# but dropping every adapter unconditionally isn't desirable: pipe.unload_lora_weights()
pipe.load_lora_weights(sd, adapter_name=aname)  # raises KeyError
```
Logs
```
Traceback (most recent call last):
  File "/tmp/f1d.py", line 12, in <module>
    pipe.load_lora_weights(sd, adapter_name=aname)
  File "/tmp/.venv/lib/python3.12/site-packages/diffusers/loaders/lora_pipeline.py", line 1847, in load_lora_weights
    transformer_lora_state_dict = self._maybe_expand_lora_state_dict(
                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/.venv/lib/python3.12/site-packages/diffusers/loaders/lora_pipeline.py", line 2436, in _maybe_expand_lora_state_dict
    base_weight_param = transformer_state_dict[base_param_name]
                        ~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
KeyError: 'single_transformer_blocks.0.attn.to_k.weight'
```
System Info
- 🤗 Diffusers version: 0.33.1
- Platform: Linux-6.14.6-arch1-1-x86_64-with-glibc2.41
- Running on Google Colab?: No
- Python version: 3.12.10
- PyTorch version (GPU?): 2.7.0+rocm6.3 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 0.31.4
- Transformers version: 4.52.0
- Accelerate version: 1.7.0
- PEFT version: 0.15.2
- Bitsandbytes version: not installed
- Safetensors version: 0.5.3
- xFormers version: not installed
- Accelerator: NA
- Using GPU in script?: no
- Using distributed or parallel set-up in script?: no
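The comment in the reproduction notes that `unload_lora_weights()` avoids the error but drops every adapter. A possible stopgap is to fall back to `unload_lora_weights()` only when the deleted adapter was the last one, which is exactly the situation in the reproduction. `delete_adapter_with_fallback` is a hypothetical helper, not part of diffusers; it assumes `pipe.get_list_adapters()` returns a dict mapping component names to lists of adapter names, and it has not been verified against this exact setup.

```python
def delete_adapter_with_fallback(pipe, adapter_name):
    """Hypothetical workaround, not a diffusers API.

    Deletes one adapter. If no adapters remain on any component, it also calls
    unload_lora_weights(), which the reproduction shows avoids the KeyError on
    the next load_lora_weights() call.
    """
    pipe.delete_adapters([adapter_name])
    # ASSUMPTION: get_list_adapters() returns {component_name: [adapter names]}.
    remaining = pipe.get_list_adapters()
    if not any(remaining.values()):
        # Nothing left to keep, so unloading everything loses nothing.
        pipe.unload_lora_weights()


# Usage with the reproduction's pipeline (not executed here):
# delete_adapter_with_fallback(pipe, "bml")
# pipe.load_lora_weights(sd, adapter_name="bml")
```

When other adapters are still loaded, the helper behaves exactly like `delete_adapters`, so it does not address any failure that might occur in that case.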
Who can help?
sayakpaul