Describe the bug
Proposal to update the following script for XLabs Flux LoRA conversion, due to a mismatch between keys in the state dictionary:
src/diffusers/loaders/lora_conversion_utils.py
When mapping single_blocks layers, if the LoRA trained on Flux contains single_blocks, these keys are never converted and never removed from old_state_dict (see lines 635-655), so the following ValueError is reached:
if len(old_state_dict) > 0:
    raise ValueError(f"`old_state_dict` should be at this point but has: {list(old_state_dict.keys())}.")
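The root of the mismatch, as a minimal sketch (not the diffusers source), assuming the v0.31.0 single_blocks branch only matches the suffixed adapter names (qkv_lora1/qkv_lora2) that XLabs uses for double_blocks, while its single_blocks weights carry unsuffixed names:

old_key = "single_blocks.1.processor.qkv_lora.down.weight"
# The suffixed check used for double_blocks never matches, so the key is
# neither converted nor popped from old_state_dict:
print("qkv_lora1" in old_key)  # False
# The unsuffixed check from the proposal below does match:
print("qkv_lora" in old_key)   # True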
For example, here are the keys of a working Flux LoRA (XLabs-AI/flux-RealismLora); it does not contain single_blocks:
['double_blocks.0.processor.proj_lora1.down.weight', 'double_blocks.0.processor.proj_lora1.up.weight', 'double_blocks.0.processor.proj_lora2.down.weight', 'double_blocks.0.processor.proj_lora2.up.weight', 'double_blocks.0.processor.qkv_lora1.down.weight', 'double_blocks.0.processor.qkv_lora1.up.weight', 'double_blocks.0.processor.qkv_lora2.down.weight', 'double_blocks.0.processor.qkv_lora2.up.weight', 'double_blocks.1.processor.proj_lora1.down.weight', 'double_blocks.1.processor.proj_lora1.up.weight', 'double_blocks.1.processor.proj_lora2.down.weight', 'double_blocks.1.processor.proj_lora2.up.weight', 'double_blocks.1.processor.qkv_lora1.down.weight', 'double_blocks.1.processor.qkv_lora1.up.weight', 'double_blocks.1.processor.qkv_lora2.down.weight', 'double_blocks.1.processor.qkv_lora2.up.weight', 'double_blocks.10.processor.proj_lora1.down.weight', 'double_blocks.10.processor.proj_lora1.up.weight', 'double_blocks.10.processor.proj_lora2.down.weight', 'double_blocks.10.processor.proj_lora2.up.weight', 'double_blocks.10.processor.qkv_lora1.down.weight', 'double_blocks.10.processor.qkv_lora1.up.weight', 'double_blocks.10.processor.qkv_lora2.down.weight', 'double_blocks.10.processor.qkv_lora2.up.weight', 'double_blocks.11.processor.proj_lora1.down.weight', 'double_blocks.11.processor.proj_lora1.up.weight', 'double_blocks.11.processor.proj_lora2.down.weight', 'double_blocks.11.processor.proj_lora2.up.weight', 'double_blocks.11.processor.qkv_lora1.down.weight', 'double_blocks.11.processor.qkv_lora1.up.weight', 'double_blocks.11.processor.qkv_lora2.down.weight', 'double_blocks.11.processor.qkv_lora2.up.weight', 'double_blocks.12.processor.proj_lora1.down.weight', 'double_blocks.12.processor.proj_lora1.up.weight', 'double_blocks.12.processor.proj_lora2.down.weight', 'double_blocks.12.processor.proj_lora2.up.weight', 'double_blocks.12.processor.qkv_lora1.down.weight', 'double_blocks.12.processor.qkv_lora1.up.weight', 'double_blocks.12.processor.qkv_lora2.down.weight', 'double_blocks.12.processor.qkv_lora2.up.weight', 'double_blocks.13.processor.proj_lora1.down.weight', 'double_blocks.13.processor.proj_lora1.up.weight', 'double_blocks.13.processor.proj_lora2.down.weight', 'double_blocks.13.processor.proj_lora2.up.weight', 'double_blocks.13.processor.qkv_lora1.down.weight', 'double_blocks.13.processor.qkv_lora1.up.weight', 'double_blocks.13.processor.qkv_lora2.down.weight', 'double_blocks.13.processor.qkv_lora2.up.weight', 'double_blocks.14.processor.proj_lora1.down.weight', 'double_blocks.14.processor.proj_lora1.up.weight', 'double_blocks.14.processor.proj_lora2.down.weight', 'double_blocks.14.processor.proj_lora2.up.weight', 'double_blocks.14.processor.qkv_lora1.down.weight', 'double_blocks.14.processor.qkv_lora1.up.weight', 'double_blocks.14.processor.qkv_lora2.down.weight', 'double_blocks.14.processor.qkv_lora2.up.weight', 'double_blocks.15.processor.proj_lora1.down.weight', 'double_blocks.15.processor.proj_lora1.up.weight', 'double_blocks.15.processor.proj_lora2.down.weight', 'double_blocks.15.processor.proj_lora2.up.weight', 'double_blocks.15.processor.qkv_lora1.down.weight', 'double_blocks.15.processor.qkv_lora1.up.weight', 'double_blocks.15.processor.qkv_lora2.down.weight', 'double_blocks.15.processor.qkv_lora2.up.weight', 'double_blocks.16.processor.proj_lora1.down.weight', 'double_blocks.16.processor.proj_lora1.up.weight', 'double_blocks.16.processor.proj_lora2.down.weight', 'double_blocks.16.processor.proj_lora2.up.weight', 'double_blocks.16.processor.qkv_lora1.down.weight', 
'double_blocks.16.processor.qkv_lora1.up.weight', 'double_blocks.16.processor.qkv_lora2.down.weight', 'double_blocks.16.processor.qkv_lora2.up.weight', 'double_blocks.17.processor.proj_lora1.down.weight', 'double_blocks.17.processor.proj_lora1.up.weight', 'double_blocks.17.processor.proj_lora2.down.weight', 'double_blocks.17.processor.proj_lora2.up.weight', 'double_blocks.17.processor.qkv_lora1.down.weight', 'double_blocks.17.processor.qkv_lora1.up.weight', 'double_blocks.17.processor.qkv_lora2.down.weight', 'double_blocks.17.processor.qkv_lora2.up.weight', 'double_blocks.18.processor.proj_lora1.down.weight', 'double_blocks.18.processor.proj_lora1.up.weight', 'double_blocks.18.processor.proj_lora2.down.weight', 'double_blocks.18.processor.proj_lora2.up.weight', 'double_blocks.18.processor.qkv_lora1.down.weight', 'double_blocks.18.processor.qkv_lora1.up.weight', 'double_blocks.18.processor.qkv_lora2.down.weight', 'double_blocks.18.processor.qkv_lora2.up.weight', 'double_blocks.2.processor.proj_lora1.down.weight', 'double_blocks.2.processor.proj_lora1.up.weight', 'double_blocks.2.processor.proj_lora2.down.weight', 'double_blocks.2.processor.proj_lora2.up.weight', 'double_blocks.2.processor.qkv_lora1.down.weight', 'double_blocks.2.processor.qkv_lora1.up.weight', 'double_blocks.2.processor.qkv_lora2.down.weight', 'double_blocks.2.processor.qkv_lora2.up.weight', 'double_blocks.3.processor.proj_lora1.down.weight', 'double_blocks.3.processor.proj_lora1.up.weight', 'double_blocks.3.processor.proj_lora2.down.weight', 'double_blocks.3.processor.proj_lora2.up.weight', 'double_blocks.3.processor.qkv_lora1.down.weight', 'double_blocks.3.processor.qkv_lora1.up.weight', 'double_blocks.3.processor.qkv_lora2.down.weight', 'double_blocks.3.processor.qkv_lora2.up.weight', 'double_blocks.4.processor.proj_lora1.down.weight', 'double_blocks.4.processor.proj_lora1.up.weight', 'double_blocks.4.processor.proj_lora2.down.weight', 'double_blocks.4.processor.proj_lora2.up.weight', 'double_blocks.4.processor.qkv_lora1.down.weight', 'double_blocks.4.processor.qkv_lora1.up.weight', 'double_blocks.4.processor.qkv_lora2.down.weight', 'double_blocks.4.processor.qkv_lora2.up.weight', 'double_blocks.5.processor.proj_lora1.down.weight', 'double_blocks.5.processor.proj_lora1.up.weight', 'double_blocks.5.processor.proj_lora2.down.weight', 'double_blocks.5.processor.proj_lora2.up.weight', 'double_blocks.5.processor.qkv_lora1.down.weight', 'double_blocks.5.processor.qkv_lora1.up.weight', 'double_blocks.5.processor.qkv_lora2.down.weight', 'double_blocks.5.processor.qkv_lora2.up.weight', 'double_blocks.6.processor.proj_lora1.down.weight', 'double_blocks.6.processor.proj_lora1.up.weight', 'double_blocks.6.processor.proj_lora2.down.weight', 'double_blocks.6.processor.proj_lora2.up.weight', 'double_blocks.6.processor.qkv_lora1.down.weight', 'double_blocks.6.processor.qkv_lora1.up.weight', 'double_blocks.6.processor.qkv_lora2.down.weight', 'double_blocks.6.processor.qkv_lora2.up.weight', 'double_blocks.7.processor.proj_lora1.down.weight', 'double_blocks.7.processor.proj_lora1.up.weight', 'double_blocks.7.processor.proj_lora2.down.weight', 'double_blocks.7.processor.proj_lora2.up.weight', 'double_blocks.7.processor.qkv_lora1.down.weight', 'double_blocks.7.processor.qkv_lora1.up.weight', 'double_blocks.7.processor.qkv_lora2.down.weight', 'double_blocks.7.processor.qkv_lora2.up.weight', 'double_blocks.8.processor.proj_lora1.down.weight', 'double_blocks.8.processor.proj_lora1.up.weight', 'double_blocks.8.processor.proj_lora2.down.weight', 
'double_blocks.8.processor.proj_lora2.up.weight', 'double_blocks.8.processor.qkv_lora1.down.weight', 'double_blocks.8.processor.qkv_lora1.up.weight', 'double_blocks.8.processor.qkv_lora2.down.weight', 'double_blocks.8.processor.qkv_lora2.up.weight', 'double_blocks.9.processor.proj_lora1.down.weight', 'double_blocks.9.processor.proj_lora1.up.weight', 'double_blocks.9.processor.proj_lora2.down.weight', 'double_blocks.9.processor.proj_lora2.up.weight', 'double_blocks.9.processor.qkv_lora1.down.weight', 'double_blocks.9.processor.qkv_lora1.up.weight', 'double_blocks.9.processor.qkv_lora2.down.weight', 'double_blocks.9.processor.qkv_lora2.up.weight']
And below, the keys of a LoRA trained with the current XLabs code, which does contain single_blocks:
['double_blocks.0.processor.proj_lora1.down.weight', 'double_blocks.0.processor.proj_lora1.up.weight', 'double_blocks.0.processor.proj_lora2.down.weight', 'double_blocks.0.processor.proj_lora2.up.weight', 'double_blocks.0.processor.qkv_lora1.down.weight', 'double_blocks.0.processor.qkv_lora1.up.weight', 'double_blocks.0.processor.qkv_lora2.down.weight', 'double_blocks.0.processor.qkv_lora2.up.weight', 'double_blocks.1.processor.proj_lora1.down.weight', 'double_blocks.1.processor.proj_lora1.up.weight', 'double_blocks.1.processor.proj_lora2.down.weight', 'double_blocks.1.processor.proj_lora2.up.weight', 'double_blocks.1.processor.qkv_lora1.down.weight', 'double_blocks.1.processor.qkv_lora1.up.weight', 'double_blocks.1.processor.qkv_lora2.down.weight', 'double_blocks.1.processor.qkv_lora2.up.weight', 'double_blocks.10.processor.proj_lora1.down.weight', 'double_blocks.10.processor.proj_lora1.up.weight', 'double_blocks.10.processor.proj_lora2.down.weight', 'double_blocks.10.processor.proj_lora2.up.weight', 'double_blocks.10.processor.qkv_lora1.down.weight', 'double_blocks.10.processor.qkv_lora1.up.weight', 'double_blocks.10.processor.qkv_lora2.down.weight', 'double_blocks.10.processor.qkv_lora2.up.weight', 'double_blocks.11.processor.proj_lora1.down.weight', 'double_blocks.11.processor.proj_lora1.up.weight', 'double_blocks.11.processor.proj_lora2.down.weight', 'double_blocks.11.processor.proj_lora2.up.weight', 'double_blocks.11.processor.qkv_lora1.down.weight', 'double_blocks.11.processor.qkv_lora1.up.weight', 'double_blocks.11.processor.qkv_lora2.down.weight', 'double_blocks.11.processor.qkv_lora2.up.weight', 'double_blocks.12.processor.proj_lora1.down.weight', 'double_blocks.12.processor.proj_lora1.up.weight', 'double_blocks.12.processor.proj_lora2.down.weight', 'double_blocks.12.processor.proj_lora2.up.weight', 'double_blocks.12.processor.qkv_lora1.down.weight', 'double_blocks.12.processor.qkv_lora1.up.weight', 'double_blocks.12.processor.qkv_lora2.down.weight', 'double_blocks.12.processor.qkv_lora2.up.weight', 'double_blocks.13.processor.proj_lora1.down.weight', 'double_blocks.13.processor.proj_lora1.up.weight', 'double_blocks.13.processor.proj_lora2.down.weight', 'double_blocks.13.processor.proj_lora2.up.weight', 'double_blocks.13.processor.qkv_lora1.down.weight', 'double_blocks.13.processor.qkv_lora1.up.weight', 'double_blocks.13.processor.qkv_lora2.down.weight', 'double_blocks.13.processor.qkv_lora2.up.weight', 'double_blocks.14.processor.proj_lora1.down.weight', 'double_blocks.14.processor.proj_lora1.up.weight', 'double_blocks.14.processor.proj_lora2.down.weight', 'double_blocks.14.processor.proj_lora2.up.weight', 'double_blocks.14.processor.qkv_lora1.down.weight', 'double_blocks.14.processor.qkv_lora1.up.weight', 'double_blocks.14.processor.qkv_lora2.down.weight', 'double_blocks.14.processor.qkv_lora2.up.weight', 'double_blocks.15.processor.proj_lora1.down.weight', 'double_blocks.15.processor.proj_lora1.up.weight', 'double_blocks.15.processor.proj_lora2.down.weight', 'double_blocks.15.processor.proj_lora2.up.weight', 'double_blocks.15.processor.qkv_lora1.down.weight', 'double_blocks.15.processor.qkv_lora1.up.weight', 'double_blocks.15.processor.qkv_lora2.down.weight', 'double_blocks.15.processor.qkv_lora2.up.weight', 'double_blocks.16.processor.proj_lora1.down.weight', 'double_blocks.16.processor.proj_lora1.up.weight', 'double_blocks.16.processor.proj_lora2.down.weight', 'double_blocks.16.processor.proj_lora2.up.weight', 'double_blocks.16.processor.qkv_lora1.down.weight', 
'double_blocks.16.processor.qkv_lora1.up.weight', 'double_blocks.16.processor.qkv_lora2.down.weight', 'double_blocks.16.processor.qkv_lora2.up.weight', 'double_blocks.17.processor.proj_lora1.down.weight', 'double_blocks.17.processor.proj_lora1.up.weight', 'double_blocks.17.processor.proj_lora2.down.weight', 'double_blocks.17.processor.proj_lora2.up.weight', 'double_blocks.17.processor.qkv_lora1.down.weight', 'double_blocks.17.processor.qkv_lora1.up.weight', 'double_blocks.17.processor.qkv_lora2.down.weight', 'double_blocks.17.processor.qkv_lora2.up.weight', 'double_blocks.18.processor.proj_lora1.down.weight', 'double_blocks.18.processor.proj_lora1.up.weight', 'double_blocks.18.processor.proj_lora2.down.weight', 'double_blocks.18.processor.proj_lora2.up.weight', 'double_blocks.18.processor.qkv_lora1.down.weight', 'double_blocks.18.processor.qkv_lora1.up.weight', 'double_blocks.18.processor.qkv_lora2.down.weight', 'double_blocks.18.processor.qkv_lora2.up.weight', 'double_blocks.2.processor.proj_lora1.down.weight', 'double_blocks.2.processor.proj_lora1.up.weight', 'double_blocks.2.processor.proj_lora2.down.weight', 'double_blocks.2.processor.proj_lora2.up.weight', 'double_blocks.2.processor.qkv_lora1.down.weight', 'double_blocks.2.processor.qkv_lora1.up.weight', 'double_blocks.2.processor.qkv_lora2.down.weight', 'double_blocks.2.processor.qkv_lora2.up.weight', 'double_blocks.3.processor.proj_lora1.down.weight', 'double_blocks.3.processor.proj_lora1.up.weight', 'double_blocks.3.processor.proj_lora2.down.weight', 'double_blocks.3.processor.proj_lora2.up.weight', 'double_blocks.3.processor.qkv_lora1.down.weight', 'double_blocks.3.processor.qkv_lora1.up.weight', 'double_blocks.3.processor.qkv_lora2.down.weight', 'double_blocks.3.processor.qkv_lora2.up.weight', 'double_blocks.4.processor.proj_lora1.down.weight', 'double_blocks.4.processor.proj_lora1.up.weight', 'double_blocks.4.processor.proj_lora2.down.weight', 'double_blocks.4.processor.proj_lora2.up.weight', 'double_blocks.4.processor.qkv_lora1.down.weight', 'double_blocks.4.processor.qkv_lora1.up.weight', 'double_blocks.4.processor.qkv_lora2.down.weight', 'double_blocks.4.processor.qkv_lora2.up.weight', 'double_blocks.5.processor.proj_lora1.down.weight', 'double_blocks.5.processor.proj_lora1.up.weight', 'double_blocks.5.processor.proj_lora2.down.weight', 'double_blocks.5.processor.proj_lora2.up.weight', 'double_blocks.5.processor.qkv_lora1.down.weight', 'double_blocks.5.processor.qkv_lora1.up.weight', 'double_blocks.5.processor.qkv_lora2.down.weight', 'double_blocks.5.processor.qkv_lora2.up.weight', 'double_blocks.6.processor.proj_lora1.down.weight', 'double_blocks.6.processor.proj_lora1.up.weight', 'double_blocks.6.processor.proj_lora2.down.weight', 'double_blocks.6.processor.proj_lora2.up.weight', 'double_blocks.6.processor.qkv_lora1.down.weight', 'double_blocks.6.processor.qkv_lora1.up.weight', 'double_blocks.6.processor.qkv_lora2.down.weight', 'double_blocks.6.processor.qkv_lora2.up.weight', 'double_blocks.7.processor.proj_lora1.down.weight', 'double_blocks.7.processor.proj_lora1.up.weight', 'double_blocks.7.processor.proj_lora2.down.weight', 'double_blocks.7.processor.proj_lora2.up.weight', 'double_blocks.7.processor.qkv_lora1.down.weight', 'double_blocks.7.processor.qkv_lora1.up.weight', 'double_blocks.7.processor.qkv_lora2.down.weight', 'double_blocks.7.processor.qkv_lora2.up.weight', 'double_blocks.8.processor.proj_lora1.down.weight', 'double_blocks.8.processor.proj_lora1.up.weight', 'double_blocks.8.processor.proj_lora2.down.weight', 
'double_blocks.8.processor.proj_lora2.up.weight', 'double_blocks.8.processor.qkv_lora1.down.weight', 'double_blocks.8.processor.qkv_lora1.up.weight', 'double_blocks.8.processor.qkv_lora2.down.weight', 'double_blocks.8.processor.qkv_lora2.up.weight', 'double_blocks.9.processor.proj_lora1.down.weight', 'double_blocks.9.processor.proj_lora1.up.weight', 'double_blocks.9.processor.proj_lora2.down.weight', 'double_blocks.9.processor.proj_lora2.up.weight', 'double_blocks.9.processor.qkv_lora1.down.weight', 'double_blocks.9.processor.qkv_lora1.up.weight', 'double_blocks.9.processor.qkv_lora2.down.weight', 'double_blocks.9.processor.qkv_lora2.up.weight', 'single_blocks.1.processor.proj_lora.down.weight', 'single_blocks.1.processor.proj_lora.up.weight', 'single_blocks.1.processor.qkv_lora.down.weight', 'single_blocks.1.processor.qkv_lora.up.weight', 'single_blocks.2.processor.proj_lora.down.weight', 'single_blocks.2.processor.proj_lora.up.weight', 'single_blocks.2.processor.qkv_lora.down.weight', 'single_blocks.2.processor.qkv_lora.up.weight', 'single_blocks.3.processor.proj_lora.down.weight', 'single_blocks.3.processor.proj_lora.up.weight', 'single_blocks.3.processor.qkv_lora.down.weight', 'single_blocks.3.processor.qkv_lora.up.weight', 'single_blocks.4.processor.proj_lora.down.weight', 'single_blocks.4.processor.proj_lora.up.weight', 'single_blocks.4.processor.qkv_lora.down.weight', 'single_blocks.4.processor.qkv_lora.up.weight']
The script works after changing lines 639-642 to:
if "proj_lora" in old_key:
new_key += ".proj_out"
elif "qkv_lora" in old_key and "up" not in old_key:
handle_qkv(old_state_dict, new_state_dict, old_key, [
f"transformer.single_transformer_blocks.{block_num}.norm.linear"
])
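For illustration, a self-contained sketch of the proposed mapping applied to the keys that previously remained in old_state_dict. The handle_qkv here is a hypothetical stand-in for the helper of the same name in lora_conversion_utils.py, reduced to a plain down/up rename:

import re

def handle_qkv(old_sd, new_sd, old_key, new_keys):
    # Stand-in: consume the down/up pair and store it under the new prefix.
    new_sd[new_keys[0] + ".lora_A.weight"] = old_sd.pop(old_key)
    new_sd[new_keys[0] + ".lora_B.weight"] = old_sd.pop(old_key.replace(".down.", ".up."))

old_state_dict = {
    "single_blocks.1.processor.qkv_lora.down.weight": "down-tensor",
    "single_blocks.1.processor.qkv_lora.up.weight": "up-tensor",
}
new_state_dict = {}

for old_key in list(old_state_dict.keys()):
    if old_key not in old_state_dict:  # up weight already consumed by handle_qkv
        continue
    block_num = re.search(r"single_blocks\.(\d+)", old_key).group(1)
    if "qkv_lora" in old_key and "up" not in old_key:
        handle_qkv(old_state_dict, new_state_dict, old_key,
                   [f"transformer.single_transformer_blocks.{block_num}.norm.linear"])

print(list(new_state_dict))  # norm.linear lora_A/lora_B keys for block 1
print(old_state_dict)        # {} -> the final length check no longer raises

With the relaxed substring check, each up weight is consumed together with its down counterpart, so nothing is left over when the final len(old_state_dict) > 0 check runs.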
Related PR: #9295 (@sayakpaul)
Reproduction
import torch
from diffusers import DiffusionPipeline
model_path = "black-forest-labs/FLUX.1-dev"
pipe = DiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.bfloat16)
lora_model_path = "XLabs-AI/flslux-RealismLora"
# lora_model_path = "<PATH-LoRA-trained-Xlabs.safetensors>"
pipe.load_lora_weights(lora_model_path, adapter_name="lora_A")
Logs
When a custom LoRA trained with the XLabs code and containing single_blocks is loaded:
File "/home/.pyenv/versions/xflux/lib/python3.10/site-packages/diffusers/loaders/lora_conversion_utils.py", line 658, in _convert_xlabs_flux_lora_to_diffusers
raise ValueError(f"`old_state_dict` should be at this point but has: {list(old_state_dict.keys())}.")
ValueError: `old_state_dict` should be at this point but has: ['single_blocks.1.processor.qkv_lora.down.weight', 'single_blocks.1.processor.qkv_lora.up.weight', 'single_blocks.2.processor.qkv_lora.down.weight', 'single_blocks.2.processor.qkv_lora.up.weight', 'single_blocks.3.processor.qkv_lora.down.weight', 'single_blocks.3.processor.qkv_lora.up.weight', 'single_blocks.4.processor.qkv_lora.down.weight', 'single_blocks.4.processor.qkv_lora.up.weight'].
System Info
- 🤗 Diffusers version: 0.31.0
- Platform: Linux-5.10.0-33-cloud-amd64-x86_64-with-glibc2.31
- Running on Google Colab?: No
- Python version: 3.10.14
- PyTorch version (GPU?): 2.4.0+cu121 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 0.24.5
- Transformers version: 4.43.3
- Accelerate version: 0.30.1
- PEFT version: 0.13.2
- Bitsandbytes version: not installed
- Safetensors version: 0.4.5
- xFormers version: not installed
- Accelerator: NVIDIA A100-SXM4-80GB, 81920 MiB
- Using GPU in script?:
- Using distributed or parallel set-up in script?: