Skip to content

Commit 5801679

Browse files
committed
Improve dtype mismatch handling for bnb + LoRA.
1 parent 0706786 commit 5801679

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

src/diffusers/loaders/lora_pipeline.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,11 @@ def _maybe_dequantize_weight_for_expanded_lora(model, module):
9494
weight_on_cpu = True
9595

9696
if is_bnb_4bit_quantized:
97+
if module.weight.quant_state.dtype != model.dtype:
98+
raise ValueError(
99+
f"Model is in {model.dtype} dtype while the current module weight will be dequantized to {module.weight.quant_state.dtype} dtype. "
100+
f"Please pass {module.weight.quant_state.dtype} as `torch_dtype` in `from_pretrained()`."
101+
)
97102
module_weight = dequantize_bnb_weight(
98103
module.weight.cuda() if weight_on_cpu else module.weight,
99104
state=module.weight.quant_state,

0 commit comments

Comments
 (0)