diff --git a/src/peft/utils/integrations.py b/src/peft/utils/integrations.py index dc5ae465db..41ec60d6c3 100644 --- a/src/peft/utils/integrations.py +++ b/src/peft/utils/integrations.py @@ -88,9 +88,6 @@ def dequantize_bnb_weight(weight: torch.nn.Parameter, state=None): """Helper function to dequantize 4bit or 8bit bnb weights.""" import bitsandbytes as bnb - if state.SCB is None: - state.SCB = weight.SCB - device = weight.device cls_name = weight.__class__.__name__ @@ -98,6 +95,16 @@ def dequantize_bnb_weight(weight: torch.nn.Parameter, state=None): dequantized = bnb.functional.dequantize_4bit(weight.data, weight.quant_state) return dequantized + # 8bit case + if state is None: + raise ValueError( + "No `state` was passed for bnb 8bit quantized weights. Please open an issue on the PEFT repository and " + "report the error: https://github.com/huggingface/peft/issues" + ) + + if state.SCB is None: + state.SCB = weight.SCB + if hasattr(bnb.functional, "int8_vectorwise_dequant"): # Use bitsandbytes API if available (requires v0.45.0+) dequantized = bnb.functional.int8_vectorwise_dequant(weight.data, state.SCB)