 from bitsandbytes.functional import (
     QuantState,
     get_4bit_type,
+    create_dynamic_map,
 )

 try:
@@ -279,8 +280,9 @@ def mm_dequant_impl(
     0.8333333: 3,  # 0b0011
 }

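+# create_dynamic_map() builds bitsandbytes' 256-entry dynamic 8-bit quantization
+# code; the list below is scanned as thresholds when bucketing scaled values.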
+INT8_QUANT_TABLE = create_dynamic_map().tolist()
+

-@_maybe_torch_compile
 def quantize_4bit_impl(
     A: Tensor,
     absmax: Tensor = None,
@@ -314,7 +316,7 @@ def quantize_4bit_impl(
         tuple(torch.Tensor, torch.Size, torch.dtype, int):
             The quantization state to undo the quantization.
     """
-    if quant_type not in ["nf4", "fp4"]:
+    if quant_type not in ["nf4", "fp4", "int8"]:
         raise NotImplementedError(f"4-bit quantization data type {quant_type} is not implemented for CPU/XPU.")
     if quant_type == "fp4":
         warnings.warn("fp4 quantization is currently slow on CPU/XPU. Please Use nf4 instead for better performance.")
@@ -355,14 +357,35 @@ def quantize_4bit_impl(
         for key, val in FP4_QUANT_TABLE.items():
             out_uint8[abs_scaled_A > key] = val
         out_uint8 += sign.to(torch.uint8) * 8
-        if out_uint8.size(-1) % 2:
-            out_uint8 = torch.nn.functional.pad(out_uint8, (0, 1), value=0)
-        out[:] = out_uint8[1::2].bitwise_left_shift(4).bitwise_or_(out_uint8[::2])
+    elif quant_type == "int8":
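+        # Each element ends up with the largest index i whose table entry it
+        # exceeds, mirroring the threshold-scan lookup used for FP4 above.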
+        for i in range(len(INT8_QUANT_TABLE)):
+            out_uint8[scaled_A > INT8_QUANT_TABLE[i]] = i

-    code = get_4bit_type(quant_type, device=A.device)
+    if quant_type != "int8":
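+        # Pack two 4-bit codes per byte: odd elements take the high nibble.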
+        if out_uint8.size(-1) % 2:
+            out_uint8 = torch.nn.functional.pad(out_uint8, (0, 1), value=0)
+        out[:] = out_uint8[1::2].bitwise_left_shift(4).bitwise_or_(out_uint8[::2])
+
+        code = get_4bit_type(quant_type, device=A.device)
+    else:
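+        # int8 keeps one byte per element; the dynamic map itself becomes the code.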
+        out = out_uint8
+        code = torch.tensor(INT8_QUANT_TABLE, device=A.device)

     if compress_statistics:
-        raise NotImplementedError("bnb_4bit_use_double_quant is not supported yet for CPU/XPU")
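+        # Double quantization: shift absmax by its mean, then quantize the shifted
+        # statistics to 8 bits (blocksize 256) so they take roughly a quarter of
+        # the memory.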
+        offset = absmax.mean()
+        absmax -= offset
+        qabsmax, state2 = quantize_4bit_impl(absmax, blocksize=256, quant_type="int8")
+        del absmax
+        state = QuantState(
+            absmax=qabsmax,
+            shape=input_shape,
+            dtype=A.dtype,
+            blocksize=blocksize,
+            code=code,
+            quant_type=quant_type,
+            offset=offset,
+            state2=state2,
+        )
     else:
         state = QuantState(
             absmax=absmax,
@@ -376,6 +399,14 @@ def quantize_4bit_impl(
     return out.unsqueeze(0), state


+def dequant_8bit(A, offset, quant_state):
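+    # Invert the double quantization of absmax: look up the 8-bit codes in the
+    # dynamic map, undo the mean offset, then rescale by the second-level absmax.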
+    assert A.dtype == torch.uint8
+    absmax = quant_state.code[A.reshape(-1).int()]
+    absmax += offset
+    absmax = (absmax.view(-1, 256) * quant_state.absmax.view(-1, 1)).reshape(quant_state.shape).to(quant_state.dtype)
+    return absmax
+
+
 @_maybe_torch_compile
 def dequantize_4bit_impl(
     A: Tensor,
@@ -438,7 +469,7 @@ def dequantize_4bit_impl(
     )

     if quant_state.nested:
-        raise NotImplementedError("bnb_4bit_use_double_quant is not supported yet for CPU/XPU")
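+        # absmax was double-quantized at quantization time; restore it first.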
+        absmax = dequant_8bit(absmax, quant_state.offset, quant_state.state2)

     if ipex_cpu_only and _ipex_cpu_version_prereq(2, 5) and getattr(quant_state, "ipex", False):
         A = torch.ops.ipex_prepack.woq_linear_unpack_weight(A, "nf4", quant_state.shape, 2)