@@ -154,7 +154,7 @@ def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name


 # Copied from PEFT: https://github.com/huggingface/peft/blob/47b3712898539569c02ec5b3ed4a6c36811331a1/src/peft/utils/integrations.py#L41
-def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None):
+def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None, dtype: torch.dtype = None):
     """
     Helper function to dequantize 4bit or 8bit bnb weights.

@@ -177,13 +177,16 @@ def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None):
     if state.SCB is None:
         state.SCB = weight.SCB

-    im = torch.eye(weight.data.shape[-1]).contiguous().half().to(weight.device)
-    im, imt, SCim, SCimt, coo_tensorim = bnb.functional.double_quant(im)
-    im, Sim = bnb.functional.transform(im, "col32")
-    if state.CxB is None:
-        state.CxB, state.SB = bnb.functional.transform(weight.data, to_order=state.formatB)
-    out32, Sout32 = bnb.functional.igemmlt(im, state.CxB, Sim, state.SB)
-    return bnb.functional.mm_dequant(out32, Sout32, SCim, state.SCB, bias=None).t()
+    if hasattr(bnb.functional, "int8_vectorwise_dequant"):
+        # Use bitsandbytes API if available (requires v0.45.0+)
+        dequantized = bnb.functional.int8_vectorwise_dequant(weight.data, state.SCB)
+    else:
+        # Multiply by (scale/127) to dequantize.
+        dequantized = weight.data * state.SCB.view(-1, 1) * 7.874015718698502e-3
+
+    if dtype:
+        dequantized = dequantized.to(dtype)
+    return dequantized


 def _create_accelerate_new_hook(old_hook):
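For reference, here is a minimal sketch (not part of the patch) of what the fallback branch computes, using a toy int8 weight and made-up per-row scales; the hard-coded constant in the diff is approximately 1/127, the reciprocal of int8's maximum magnitude:

import torch

# Toy LLM.int8()-style weight: each row was quantized as
# round(w / row_absmax * 127), so multiplying by SCB (the per-row
# absmax stored by bitsandbytes) divided by 127 recovers the
# floating-point values.
int8_weight = torch.tensor([[127, -64], [0, 32]], dtype=torch.int8)
SCB = torch.tensor([0.5, 2.0])  # made-up per-row absmax scales

dequantized = int8_weight * SCB.view(-1, 1) * (1.0 / 127.0)
print(dequantized)
# tensor([[ 0.5000, -0.2520],
#         [ 0.0000,  0.5039]])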
@@ -205,6 +208,7 @@ def _create_accelerate_new_hook(old_hook):

 def _dequantize_and_replace(
     model,
+    dtype,
     modules_to_not_convert=None,
     current_key_name=None,
     quantization_config=None,
@@ -244,7 +248,7 @@ def _dequantize_and_replace(
                 else:
                     state = None

-                new_module.weight = torch.nn.Parameter(dequantize_bnb_weight(module.weight, state))
+                new_module.weight = torch.nn.Parameter(dequantize_bnb_weight(module.weight, state, dtype))

                 if bias is not None:
                     new_module.bias = bias
@@ -280,6 +284,7 @@ def dequantize_and_replace(
 ):
     model, has_been_replaced = _dequantize_and_replace(
         model,
+        model.dtype,
         modules_to_not_convert=modules_to_not_convert,
         quantization_config=quantization_config,
     )
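As a hedged usage note (the helper name below is illustrative, not from the patch): the tail of the patched dequantize_bnb_weight only casts when a dtype is passed, since "if dtype:" is falsy for the default None, so existing callers that omit dtype keep whatever dtype bitsandbytes produced:

import torch

def cast_if_requested(dequantized: torch.Tensor, dtype: torch.dtype = None) -> torch.Tensor:
    # Mirrors the `if dtype:` tail of the patched helper: None is falsy,
    # so no cast happens unless a dtype is explicitly supplied.
    if dtype:
        dequantized = dequantized.to(dtype)
    return dequantized

w = torch.randn(4, 4)  # float32 by default
print(cast_if_requested(w).dtype)                  # torch.float32 (unchanged)
print(cast_if_requested(w, torch.bfloat16).dtype)  # torch.bfloat16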