Describe the bug
When trying to hotswap multiple Flux LoRAs, loading fails with a RuntimeError about unexpected keys:
RuntimeError: Hot swapping the adapter did not succeed, unexpected keys found: transformer_blocks.13.norm1.linear.lora_B.weight,
Reproduction
Download two Flux Dev LoRAs (this example uses http://base-weights.weights.com/cm9dm38e4061uon15341k47ss.zip and http://base-weights.weights.com/cm9dnj1840088n214rn9uych4.zip).
Unzip them and load the safetensors into memory.
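A minimal sketch of that loading step (the file names inside the zips are placeholders, adjust as needed):

from safetensors.torch import load_file

# placeholder paths to the unzipped .safetensors files
first_state_dict = load_file("first_lora/pytorch_lora_weights.safetensors")
second_state_dict = load_file("second_lora/pytorch_lora_weights.safetensors")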
import time
import torch
import logging
from diffusers import FluxPipeline
logger = logging.getLogger(__name__)
class DownloadedLora:
    def __init__(self, state_dict):
        self.state_dict = state_dict
    @property
    def model(self):
        state_dict = self.state_dict
        # return a detached clone of the state dict to avoid modifying the original
        new_state_dict = {}
        for k, v in state_dict.items():
            new_state_dict[k] = v.clone().detach()
        return new_state_dict
def test_lora_hotswap():
    logger.info(f"Initializing flux model")
    # todo - compile https://github.com/huggingface/diffusers/pull/9453 when this gets merged
    flux_base_model: FluxPipeline = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        torch_dtype=torch.bfloat16,
    )
    flux_base_model = flux_base_model.to("cuda")
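    # as I understand the API, enable_lora_hotswap must be called before the first
    # adapter is loaded; target_rank=128 pads every LoRA up to that rank so adapters
    # of different ranks can later be swapped in place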
    flux_base_model.enable_lora_hotswap(target_rank=128)
    # wrap the state dicts of the two downloaded LoRAs (loaded above)
    first_lora = DownloadedLora(state_dict=first_state_dict)
    second_lora = DownloadedLora(state_dict=second_state_dict)
    # load three adapters up front, as that is the limit of what we support;
    # they are named "1", "2", "3" and will later be enabled, disabled, or hotswapped
    flux_base_model.load_lora_weights(first_lora.model, adapter_name="1")
    flux_base_model.load_lora_weights(second_lora.model, adapter_name="2")
    flux_base_model.load_lora_weights(second_lora.model, adapter_name="3")
    logger.info("Initialized base flux model")
    should_compile = False
    if should_compile:
        flux_base_model.image_encoder = torch.compile(flux_base_model.image_encoder)
        flux_base_model.text_encoder = torch.compile(flux_base_model.text_encoder)
        flux_base_model.text_encoder_2 = torch.compile(flux_base_model.text_encoder_2)
        flux_base_model.vae = torch.compile(flux_base_model.vae)
        flux_base_model.transformer = torch.compile(
            flux_base_model.transformer, mode="max-autotune-no-cudagraphs", fullgraph=True
        )
    for i in range(5):
        start_time = time.time()
        image = flux_base_model("An image of a cat", num_inference_steps=4, guidance_scale=3.0).images[0]
        if i == 0:
            logger.info(f"Warmup: {time.time() - start_time}")
        else:
            logger.info(f"Inference time: {time.time() - start_time}")
        utc_seconds = int(time.time())
        image.save(f"hotswap_{utc_seconds}.png")
        if i == 1:
            logger.info("Hotswapping lora one")
            flux_base_model.load_lora_weights(first_lora.model, adapter_name="1", hotswap=True)
        if i == 2:
            logger.info("Hotswapping lora two")
            flux_base_model.load_lora_weights(second_lora.model, adapter_name="2", hotswap=True)
            flux_base_model.load_lora_weights(first_lora.model, adapter_name="1", hotswap=True)
Logs
2025-04-12 04:47:18 | INFO     | Initialized base flux model
100%|██████████| 4/4 [00:01<00:00,  3.64it/s]
2025-04-12 04:47:21 | INFO     | Warmup: 2.4211995601654053
100%|██████████| 4/4 [00:01<00:00,  3.79it/s]
2025-04-12 04:47:23 | INFO     | Inference time: 1.2886595726013184
2025-04-12 04:47:23 | INFO     | Hotswapping lora one
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/home/team/replay/python/hosted/utils/testing.py", line 708, in <module>
    main()
  File "/home/team/replay/python/hosted/utils/testing.py", line 704, in main
    test_lora_hotswap()
  File "/home/team/replay/python/hosted/utils/testing.py", line 667, in test_lora_hotswap
    flux_base_model.load_lora_weights(first_lora.model, adapter_name="1", hotswap=True)
  File "/home/team/.local/lib/python3.11/site-packages/diffusers/loaders/lora_pipeline.py", line 1808, in load_lora_weights
    self.load_lora_into_transformer(
  File "/home/team/.local/lib/python3.11/site-packages/diffusers/loaders/lora_pipeline.py", line 1899, in load_lora_into_transformer
    transformer.load_lora_adapter(
  File "/home/team/.local/lib/python3.11/site-packages/diffusers/loaders/peft.py", line 371, in load_lora_adapter
    hotswap_adapter_from_state_dict(
  File "/home/team/.local/lib/python3.11/site-packages/peft/utils/hotswap.py", line 431, in hotswap_adapter_from_state_dict
    raise RuntimeError(msg)
RuntimeError: Hot swapping the adapter did not succeed, unexpected keys found: transformer_blocks.14.ff.net.0.proj.lora_B.weight, single_transformer_blocks.7.attn.to_v.lora_B.weight, ...
System Info
- 🤗 Diffusers version: 0.33.0.dev0
- Platform: Linux-5.10.0-34-cloud-amd64-x86_64-with-glibc2.31
- Running on Google Colab?: No
- Python version: 3.11.11
- PyTorch version (GPU?): 2.6.0+cu124 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 0.28.1
- Transformers version: 4.50.3
- Accelerate version: 1.6.0
- PEFT version: 0.15.0
- Bitsandbytes version: 0.45.3
- Safetensors version: 0.5.3
- xFormers version: 0.0.29.post3
- Accelerator: NVIDIA H100 80GB HBM3, 81559 MiB
- Using GPU in script?:
- Using distributed or parallel set-up in script?:
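For what it's worth, my assumption is that the two LoRAs target different sets of modules, so the incoming state dict contains keys the currently loaded adapter lacks. A quick check on the two state dicts (same placeholder paths as in the reproduction):

from safetensors.torch import load_file

first_state_dict = load_file("first_lora/pytorch_lora_weights.safetensors")
second_state_dict = load_file("second_lora/pytorch_lora_weights.safetensors")

# keys present in one LoRA but not the other; a non-empty difference means the
# adapters target different modules, which is what the hotswap error reports
only_in_first = sorted(set(first_state_dict) - set(second_state_dict))
only_in_second = sorted(set(second_state_dict) - set(first_state_dict))
print(f"only in first: {len(only_in_first)}, only in second: {len(only_in_second)}")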