@@ -153,8 +153,8 @@ def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name
     return model


-# Copied from PEFT: https://github.com/huggingface/peft/blob/47b3712898539569c02ec5b3ed4a6c36811331a1/src/peft/utils/integrations.py#L41
-def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None):
+# Adapted from PEFT: https://github.com/huggingface/peft/blob/6d458b300fc2ed82e19f796b53af4c97d03ea604/src/peft/utils/integrations.py#L81
+def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None, dtype: "torch.dtype" = None):
     """
     Helper function to dequantize 4bit or 8bit bnb weights.

@@ -177,13 +177,16 @@ def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None):
     if state.SCB is None:
         state.SCB = weight.SCB

-    im = torch.eye(weight.data.shape[-1]).contiguous().half().to(weight.device)
-    im, imt, SCim, SCimt, coo_tensorim = bnb.functional.double_quant(im)
-    im, Sim = bnb.functional.transform(im, "col32")
-    if state.CxB is None:
-        state.CxB, state.SB = bnb.functional.transform(weight.data, to_order=state.formatB)
-    out32, Sout32 = bnb.functional.igemmlt(im, state.CxB, Sim, state.SB)
-    return bnb.functional.mm_dequant(out32, Sout32, SCim, state.SCB, bias=None).t()
+    if hasattr(bnb.functional, "int8_vectorwise_dequant"):
+        # Use bitsandbytes API if available (requires v0.45.0+)
+        dequantized = bnb.functional.int8_vectorwise_dequant(weight.data, state.SCB)
+    else:
+        # Multiply by (scale/127) to dequantize.
+        dequantized = weight.data * state.SCB.view(-1, 1) * 7.874015718698502e-3
+
+    if dtype:
+        dequantized = dequantized.to(dtype)
+    return dequantized


 def _create_accelerate_new_hook(old_hook):
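For context on the fallback branch above: LLM.int8() keeps each weight row as int8 plus a per-row absmax scale (held in state.SCB), so dequantizing is just weight * SCB / 127 broadcast over rows, and the literal 7.874015718698502e-3 is simply 1/127 precomputed at float32 precision. The removed path obtained the same result indirectly by pushing the weight through the int8 matmul kernels (double_quant, transform, igemmlt, mm_dequant), which rely on col32 layouts that newer bitsandbytes releases deprecate. A minimal standalone sketch of the arithmetic, assuming row-wise absmax quantization; the helper names are made up here and nothing below calls bitsandbytes:

import torch

def quantize_rowwise(w: torch.Tensor):
    # Per-row absmax scale; this plays the role of state.SCB.
    scb = w.abs().amax(dim=1)
    w_int8 = torch.round(w * (127.0 / scb.view(-1, 1))).to(torch.int8)
    return w_int8, scb

def dequantize_rowwise(w_int8: torch.Tensor, scb: torch.Tensor):
    # Same math as the fallback branch: multiply by (scale / 127).
    return w_int8.float() * scb.view(-1, 1) * (1.0 / 127.0)

w = torch.randn(4, 8)
w_int8, scb = quantize_rowwise(w)
w_hat = dequantize_rowwise(w_int8, scb)
print((w - w_hat).abs().max())  # small round-off error; not an exact inverse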
@@ -205,6 +208,7 @@ def _create_accelerate_new_hook(old_hook):

 def _dequantize_and_replace(
     model,
+    dtype,
     modules_to_not_convert=None,
     current_key_name=None,
     quantization_config=None,
@@ -244,7 +248,7 @@ def _dequantize_and_replace(
                 else:
                     state = None

-                new_module.weight = torch.nn.Parameter(dequantize_bnb_weight(module.weight, state))
+                new_module.weight = torch.nn.Parameter(dequantize_bnb_weight(module.weight, state, dtype))

                 if bias is not None:
                     new_module.bias = bias
@@ -263,9 +267,10 @@ def _dequantize_and_replace(
         if len(list(module.children())) > 0:
             _, has_been_replaced = _dequantize_and_replace(
                 module,
-                modules_to_not_convert,
-                current_key_name,
-                quantization_config,
+                dtype=dtype,
+                modules_to_not_convert=modules_to_not_convert,
+                current_key_name=current_key_name,
+                quantization_config=quantization_config,
                 has_been_replaced=has_been_replaced,
             )
         # Remove the last key for recursion
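A note on why this recursive call switches to keyword arguments instead of only gaining dtype=dtype: dtype is now the second positional parameter of _dequantize_and_replace, so the old positional call would silently bind modules_to_not_convert to dtype. A tiny, purely illustrative reduction of that pitfall:

def f(model, dtype, modules_to_not_convert=None):
    return dtype

# A positional call written before `dtype` existed now misbinds:
print(f("model", ["lm_head"]))  # ["lm_head"] lands in dtype
# Keywords keep each argument attached to the intended parameter:
print(f("model", dtype="float16", modules_to_not_convert=["lm_head"]))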
@@ -280,6 +285,7 @@ def dequantize_and_replace(
 ):
     model, has_been_replaced = _dequantize_and_replace(
         model,
+        dtype=model.dtype,
         modules_to_not_convert=modules_to_not_convert,
         quantization_config=quantization_config,
     )
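Taken together, dequantize_and_replace now captures model.dtype once at the top and threads it down the recursion, so every restored torch.nn.Linear weight comes back in the model's own dtype rather than whatever the dequant kernel emits. A hedged end-to-end sketch of how this surfaces through the public API; the checkpoint name is illustrative, and running it requires bitsandbytes and a CUDA device:

from transformers import AutoModelForCausalLM, BitsAndBytesConfig

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
)

# model.dequantize() is the public entry point that ends up in
# _dequantize_and_replace above, now passing dtype=model.dtype.
model.dequantize()

# After the change, the dequantized weights should share one floating dtype.
print({p.dtype for p in model.parameters()})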