Skip to content
8 changes: 5 additions & 3 deletions src/diffusers/quantizers/bitsandbytes/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,9 +171,11 @@ def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None, dtype: "torc

if cls_name == "Params4bit":
output_tensor = bnb.functional.dequantize_4bit(weight.data, weight.quant_state)
logger.warning_once(
f"The model is going to be dequantized in {output_tensor.dtype} - if you want to upcast it to another dtype, make sure to pass the desired dtype when quantizing the model through `bnb_4bit_quant_type` argument of `BitsAndBytesConfig`"
)
msg = f"The model is going to be dequantized in {output_tensor.dtype} - if you want to upcast it to another dtype, make sure to pass the desired dtype when quantizing the model through `bnb_4bit_quant_type` argument of `BitsAndBytesConfig`"
if dtype:
msg = f"The model is going to be first dequantized in {output_tensor.dtype} and type-casted to {dtype}"
output_tensor = output_tensor.to(dtype)
logger.warning_once(msg)
return output_tensor

if state.SCB is None:
Expand Down