diff --git a/src/diffusers/quantizers/bitsandbytes/utils.py b/src/diffusers/quantizers/bitsandbytes/utils.py index e150281e81ce..5476b93a4cc2 100644 --- a/src/diffusers/quantizers/bitsandbytes/utils.py +++ b/src/diffusers/quantizers/bitsandbytes/utils.py @@ -171,9 +171,11 @@ def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None, dtype: "torc if cls_name == "Params4bit": output_tensor = bnb.functional.dequantize_4bit(weight.data, weight.quant_state) - logger.warning_once( - f"The model is going to be dequantized in {output_tensor.dtype} - if you want to upcast it to another dtype, make sure to pass the desired dtype when quantizing the model through `bnb_4bit_quant_type` argument of `BitsAndBytesConfig`" - ) + msg = f"The model is going to be dequantized in {output_tensor.dtype} - if you want to upcast it to another dtype, make sure to pass the desired dtype when quantizing the model through `bnb_4bit_quant_type` argument of `BitsAndBytesConfig`" + if dtype: + msg = f"The model is going to be first dequantized in {output_tensor.dtype} and type-casted to {dtype}" + output_tensor = output_tensor.to(dtype) + logger.warning_once(msg) return output_tensor if state.SCB is None: