Merged
5 changes: 5 additions & 0 deletions src/diffusers/loaders/lora_pipeline.py
@@ -95,6 +95,11 @@ def _maybe_dequantize_weight_for_expanded_lora(model, module):
        weight_on_cpu = True

    if is_bnb_4bit_quantized:
        if module.weight.quant_state.dtype != model.dtype:
            raise ValueError(
                f"Model is in {model.dtype} dtype while the current module weight will be dequantized to {module.weight.quant_state.dtype} dtype. "
                f"Please pass {module.weight.quant_state.dtype} as `torch_dtype` in `from_pretrained()`."
            )
        module_weight = dequantize_bnb_weight(

Member

Since we specified `dtype=model.dtype` in `dequantize_bnb_weight`, won't `module_weight` have the same dtype as the model?


Member Author

Yes, it will. But the LoRA params would not be in that dtype, since they are derived earlier from the `module_weight` data dtype. That is why, in the error trace, the error happens inside peft.


Member
@SunMarc (Apr 10, 2025)

To summarize, we have the following, right? (rough sketch below)

  • the LoRA params are created using the dtype from `module_weight` (this is probably where `module.weight.quant_state.dtype` comes in)
  • `module_weight` is dequantized using `model.dtype` (so we are not actually using `module.weight.quant_state.dtype`, no?); the `model.dtype` value comes from `torch_dtype`

-> dtype mismatch, because the LoRA params do not end up in the same dtype as `module_weight`
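
Rough sketch of that divergence (made-up tensor shapes and names, not the actual diffusers/PEFT code path):

import torch

quant_state_dtype = torch.float16   # dtype recorded in the 4-bit checkpoint's quant_state
model_dtype = torch.bfloat16        # model.dtype, i.e. the `torch_dtype` passed to `from_pretrained()`

# The LoRA params follow the module weight's original (quant_state) dtype ...
lora_weight = torch.zeros(16, 64, dtype=quant_state_dtype)

# ... while the dequantized module weight follows model.dtype.
module_weight = torch.zeros(64, 64, dtype=model_dtype)

assert lora_weight.dtype != module_weight.dtype  # the mismatch the new check guards against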


Member Author

Yeah, we don't really have any special treatment for handling the LoRA param dtype. Cc'ing @BenjaminBossan here.

> `module_weight` is dequantized using `model.dtype` (so we are not actually using `module.weight.quant_state.dtype`, no?); the `model.dtype` value comes from `torch_dtype`

Well, we do use the quant_state:

output_tensor = bnb.functional.dequantize_4bit(weight.data, weight.quant_state)

But then we also perform another type-casting:
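
(Hedged sketch of those two steps; the function name and signature here are just for illustration, not the exact diffusers helper:)

import bitsandbytes as bnb

def dequantize_then_cast(weight, dtype):
    # Step 1: bitsandbytes dequantizes with the stored quant_state, so the
    # result comes back in quant_state.dtype (float16 for the checkpoint above).
    output_tensor = bnb.functional.dequantize_4bit(weight.data, weight.quant_state)
    # Step 2: the extra cast mentioned above, to the model's dtype (e.g. bfloat16).
    return output_tensor.to(dtype)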


Member

Just to clarify: this is unrelated to the LoRA parameters. Instead, what happens is that a PEFT `LoraLayer` wraps the base layer and calls `self.base_layer(x)`, which should just return the result of the original layer. Due to the change in dtype, we encounter the dtype mismatch there.
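
A minimal mock of that call pattern (not PEFT's actual implementation, just to show where the error surfaces):

import torch.nn as nn

class ToyLoraLinear(nn.Module):
    def __init__(self, base_layer: nn.Linear, r: int = 4):
        super().__init__()
        self.base_layer = base_layer  # the original (here: dequantized) layer
        self.lora_A = nn.Linear(base_layer.in_features, r, bias=False)
        self.lora_B = nn.Linear(r, base_layer.out_features, bias=False)

    def forward(self, x):
        # The dtype error described above is raised here, in the call to the
        # *base* layer, when x and base_layer.weight end up in different dtypes.
        result = self.base_layer(x)
        return result + self.lora_B(self.lora_A(x))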


Member Author

from diffusers import FluxControlPipeline
from PIL import Image
import torch

pipe = FluxControlPipeline.from_pretrained("eramth/flux-4bit", torch_dtype=torch.bfloat16).to("cuda")

pipe("a dog", control_image=Image.new(mode="RGB", size=(256, 256)))

This works though.
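
For contrast, loading the Canny LoRA on top of that bfloat16 pipeline is what used to blow up inside peft and, with this PR, now fails fast with the new ValueError (sketch based on the test added below):

try:
    pipe.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")
except ValueError as e:
    # e.g. "Model is in torch.bfloat16 dtype while the current module weight
    # will be dequantized to torch.float16 dtype. ..."
    print(e)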


Member

Yeah, it happens inside the LoRA layer, but what I mean is that the LoRA weights are not involved; it's the call to the base layer that is causing the issue.

            module.weight.cuda() if weight_on_cpu else module.weight,
            state=module.weight.quant_state,
14 changes: 14 additions & 0 deletions tests/quantization/bnb/test_4bit.py
@@ -752,6 +752,20 @@ def test_lora_loading(self):
        max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
        self.assertTrue(max_diff < 1e-3, msg=f"{out_slice=} != {expected_slice=}")

    def test_loading_lora_with_incorrect_dtype_raises_error(self):
        self.tearDown()
        model_dtype = torch.bfloat16
        # https://huggingface.co/eramth/flux-4bit/blob/main/transformer/config.json#L23
        actual_dtype = torch.float16
        self.pipeline_4bit = FluxControlPipeline.from_pretrained("eramth/flux-4bit", torch_dtype=torch.bfloat16)
        self.pipeline_4bit.enable_model_cpu_offload()
        with self.assertRaises(ValueError) as err_context:
            self.pipeline_4bit.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")
        assert (
            f"Model is in {model_dtype} dtype while the current module weight will be dequantized to {actual_dtype} dtype."
            in str(err_context.exception)
        )


@slow
class BaseBnb4BitSerializationTests(Base4bitTests):