
Flux ControlNet LoRA fails to load when the transformer is quantized #10989

@CyberVy


Describe the bug

Loading a Flux ControlNet LoRA fails when the transformer is quantized to 4-bit with bitsandbytes.
Issue #10337 may be related to this one.

Reproduction

import torch
from diffusers import FluxControlPipeline

# The transformer in this checkpoint is pre-quantized to 4-bit with bitsandbytes.
pipe = FluxControlPipeline.from_pretrained("eramth/flux-4bit", torch_dtype=torch.float16).to("cuda")
# Fails here, while diffusers expands transformer param shapes for the Control LoRA.
pipe.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")
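
For context, the 4-bit transformer stores its Linear weights as packed integer tensors rather than float16. A quick check (a hedged sketch run after from_pretrained but before the failing load_lora_weights call; x_embedder is one of the Flux transformer's Linear layers and is my assumption, not taken from the report):

# Inspect the quantized weight storage of the layer the Control LoRA expands.
w = pipe.transformer.x_embedder.weight
print(type(w).__name__, w.dtype)  # expected (assumption): Params4bit torch.uint8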

Logs

/usr/local/lib/python3.11/dist-packages/diffusers/loaders/lora_pipeline.py in load_lora_weights(self, pretrained_model_name_or_path_or_dict, adapter_name, **kwargs)
   1549 
   1550         transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
-> 1551         has_param_with_expanded_shape = self._maybe_expand_transformer_param_shape_or_error_(
   1552             transformer, transformer_lora_state_dict, transformer_norm_state_dict
   1553         )

/usr/local/lib/python3.11/dist-packages/diffusers/loaders/lora_pipeline.py in _maybe_expand_transformer_param_shape_or_error_(cls, transformer, lora_state_dict, norm_state_dict, prefix)
   2019 
   2020                     with torch.device("meta"):
-> 2021                         expanded_module = torch.nn.Linear(
   2022                             in_features, out_features, bias=bias, dtype=module_weight.dtype
   2023                         )

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/linear.py in __init__(self, in_features, out_features, bias, device, dtype)
    103         self.in_features = in_features
    104         self.out_features = out_features
--> 105         self.weight = Parameter(
    106             torch.empty((out_features, in_features), **factory_kwargs)
    107         )

/usr/local/lib/python3.11/dist-packages/torch/nn/parameter.py in __new__(cls, data, requires_grad)
     44             # For ease of BC maintenance, keep this path for standard Tensor.
     45             # Eventually (tm), we should change the behavior for standard Tensor to match.
---> 46             return torch.Tensor._make_subclass(cls, data, requires_grad)
     47 
     48         # Path for custom tensors: set a flag on the instance to indicate parameter-ness.

RuntimeError: Only Tensors of floating point and complex dtype can require gradients
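
The root cause appears to be that _maybe_expand_transformer_param_shape_or_error_ passes the quantized module's weight dtype straight to torch.nn.Linear. With bitsandbytes 4-bit quantization the packed weight dtype is torch.uint8, and nn.Parameter rejects non-floating-point dtypes with requires_grad=True. A standalone snippet reproducing just that failure (a minimal sketch, independent of diffusers):

import torch

# nn.Linear creates its weight as a Parameter with requires_grad=True.
# Passing the packed uint8 dtype of a bnb 4-bit weight triggers the same error
# seen in the traceback above, even on the meta device.
with torch.device("meta"):
    torch.nn.Linear(3072, 3072, bias=True, dtype=torch.uint8)
# RuntimeError: Only Tensors of floating point and complex dtype can require gradients

A plausible fix direction (an assumption on my part, not the actual patch) would be to fall back to a floating-point dtype, e.g. the pipeline's torch_dtype, whenever module_weight.dtype is not floating point.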

System Info

  • πŸ€— Diffusers version: 0.33.0.dev0
  • Platform: Linux-6.1.85+-x86_64-with-glibc2.35
  • Running on Google Colab?: Yes
  • Python version: 3.11.11
  • PyTorch version (GPU?): 2.5.1+cu124 (True)
  • Flax version (CPU?/GPU?/TPU?): 0.10.4 (gpu)
  • Jax version: 0.4.33
  • JaxLib version: 0.4.33
  • Huggingface_hub version: 0.28.1
  • Transformers version: 4.48.3
  • Accelerate version: 1.3.0
  • PEFT version: 0.14.0
  • Bitsandbytes version: 0.45.3
  • Safetensors version: 0.5.3
  • xFormers version: not installed
  • Accelerator: Tesla T4, 15360 MiB

Who can help?

@hlky

Labels

bug, lora
