@@ -566,6 +566,25 @@ def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, n
 
     return out
 
+class QuantState:
+    def __init__(self, absmax, shape=None, code=None, blocksize=None, quant_type=None, dtype=None, offset=None, state2=None):
+        self.absmax = absmax
+        self.shape = shape
+        self.code = code
+        self.dtype = dtype
+        self.blocksize = blocksize
+        self.quant_type = quant_type
+        self.offset = offset
+        self.state2 = state2
+        self.nested = state2 is not None
+
+    def to(self, device):
+        # make sure the quantization state is on the right device
+        self.absmax = self.absmax.to(device)
+        if self.nested:
+            self.offset = self.offset.to(device)
+            self.state2.absmax = self.state2.absmax.to(device)
+            self.state2.code = self.state2.code.to(device)
 
 def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, out: Tensor = None, blocksize=4096, nested=False) -> Tensor:
     """
@@ -633,16 +652,16 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ou
         offset = absmax.mean()
         absmax -= offset
         qabsmax, state2 = quantize_blockwise(absmax, blocksize=blocksize, nested=False)
-        state = [qabsmax, code, blocksize, nested, A.dtype, offset, state2]
+        quant_state = QuantState(absmax=qabsmax, code=code, blocksize=blocksize, dtype=A.dtype, offset=offset, state2=state2)
     else:
-        state = [absmax, code, blocksize, nested, A.dtype, None, None]
+        quant_state = QuantState(absmax=absmax, code=code, blocksize=blocksize, dtype=A.dtype)
 
-    return out, state
+    return out, quant_state
 
 
 def dequantize_blockwise(
     A: Tensor,
-    quant_state: Tuple[Tensor, Tensor] = None,
+    quant_state: QuantState = None,
     absmax: Tensor = None,
     code: Tensor = None,
     out: Tensor = None,
@@ -659,8 +678,8 @@ def dequantize_blockwise(
     ----------
     A : torch.Tensor
         The input 8-bit tensor.
-    quant_state : tuple(torch.Tensor, torch.Tensor)
-        Tuple of code and absmax values.
+    quant_state : QuantState
+        Object with code, absmax and other quantization state components.
     absmax : torch.Tensor
         The absmax values.
     code : torch.Tensor
@@ -681,36 +700,35 @@ def dequantize_blockwise(
         code = name2qmap["dynamic"]
 
     if quant_state is None:
-        quant_state = (absmax, code, blocksize, False, torch.float32, None, None)
-
-    absmax, code, blocksize, nested, dtype, offset, state2 = quant_state
-
-    if nested:
-        absmax = dequantize_blockwise(absmax, state2)
-        absmax += offset
+        quant_state = QuantState(absmax=absmax, code=code, blocksize=blocksize, dtype=torch.float32)
+
+    absmax = quant_state.absmax
+    if quant_state.nested:
+        absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2)
+        absmax += quant_state.offset
     if absmax.dtype != torch.float32: absmax = absmax.float()
 
     if out is None:
-        out = torch.empty(A.shape, dtype=dtype, device=A.device)
+        out = torch.empty(A.shape, dtype=quant_state.dtype, device=A.device)
 
     if A.device.type != 'cpu':
         device = pre_call(A.device)
-        code = code.to(A.device)
-        if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]:
-            raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]")
+        code = quant_state.code.to(A.device)
+        if quant_state.blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]:
+            raise ValueError(f"The blocksize of {quant_state.blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]")
         is_on_gpu([A, absmax, out])
         if out.dtype == torch.float32:
-            lib.cdequantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
+            lib.cdequantize_blockwise_fp32(get_ptr(quant_state.code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(A.numel()))
         elif out.dtype == torch.float16:
-            lib.cdequantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
+            lib.cdequantize_blockwise_fp16(get_ptr(quant_state.code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(A.numel()))
        elif out.dtype == torch.bfloat16:
-            lib.cdequantize_blockwise_bf16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
+            lib.cdequantize_blockwise_bf16(get_ptr(quant_state.code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(A.numel()))
         else:
             raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
         post_call(A.device)
     else:
-        code = code.cpu()
-        lib.cdequantize_blockwise_cpu_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel()))
+        code = quant_state.code.cpu()
+        lib.cdequantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(quant_state.absmax), get_ptr(out), ct.c_longlong(quant_state.blocksize), ct.c_longlong(A.numel()))
 
     return out
 
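Taken together, the blockwise round trip now reads as follows. A sketch, assuming `bitsandbytes.functional` imported as `F` and a CUDA device (names and shapes are illustrative):

    import torch
    import bitsandbytes.functional as F

    A = torch.randn(1024, 1024, dtype=torch.float16, device='cuda')

    # quantize_blockwise now returns a QuantState instead of a positional list.
    q, quant_state = F.quantize_blockwise(A, blocksize=4096, nested=True)

    # dequantize_blockwise reads absmax/code/blocksize/dtype off the object,
    # transparently decompressing the nested absmax statistics first.
    A_deq = F.dequantize_blockwise(q, quant_state)
    print((A - A_deq).abs().max())  # small blockwise quantization error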
@@ -839,26 +857,26 @@ def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksiz
         raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
     post_call(A.device)
 
-    datatype = get_4bit_type(quant_type, device=A.device)
+    code = get_4bit_type(quant_type, device=A.device)
 
     if compress_statistics:
         offset = absmax.mean()
         absmax -= offset
         qabsmax, state2 = quantize_blockwise(absmax, blocksize=256)
         del absmax
-        state = [qabsmax, input_shape, A.dtype, blocksize, [offset, state2], quant_type, datatype]
+        state = QuantState(absmax=qabsmax, shape=input_shape, dtype=A.dtype, blocksize=blocksize, code=code, quant_type=quant_type, offset=offset, state2=state2)
     else:
-        state = [absmax, input_shape, A.dtype, blocksize, None, quant_type, datatype]
+        state = QuantState(absmax=absmax, shape=input_shape, dtype=A.dtype, blocksize=blocksize, code=code, quant_type=quant_type)
 
     return out, state
 
-def dequantize_fp4(A: Tensor, quant_state: Tuple[Tensor, Tensor] = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64) -> Tensor:
+def dequantize_fp4(A: Tensor, quant_state: QuantState = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64) -> Tensor:
     return dequantize_4bit(A, quant_state, absmax, out, blocksize, 'fp4')
 
-def dequantize_nf4(A: Tensor, quant_state: Tuple[Tensor, Tensor] = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64) -> Tensor:
+def dequantize_nf4(A: Tensor, quant_state: QuantState = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64) -> Tensor:
     return dequantize_4bit(A, quant_state, absmax, out, blocksize, 'nf4')
 
-def dequantize_4bit(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64, quant_type='fp4') -> Tensor:
+def dequantize_4bit(A: Tensor, quant_state: QuantState = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64, quant_type='fp4') -> Tensor:
     """
     Dequantizes FP4 blockwise quantized values.
@@ -868,8 +886,8 @@ def dequantize_4bit(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax:
     ----------
     A : torch.Tensor
         The input 8-bit tensor (packed 4-bit values).
-    quant_state : tuple(torch.Tensor, torch.Size, torch.dtype)
-        Tuple of absmax values, original tensor shape and original dtype.
+    quant_state : QuantState
+        Object with quantization stats, including absmax values, original tensor shape and original dtype.
     absmax : torch.Tensor
         The absmax values.
     out : torch.Tensor
@@ -892,41 +910,40 @@ def dequantize_4bit(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax:
 
     if quant_state is None:
         assert absmax is not None and out is not None
-        shape = out.shape
-        dtype = out.dtype
+
+        quant_state = QuantState(absmax=absmax, shape=out.shape, dtype=out.dtype, blocksize=blocksize, quant_type=quant_type)
+
     else:
-        absmax, shape, dtype, blocksize, compressed_stats, quant_type, data_type = quant_state
+        absmax = quant_state.absmax
 
 
-    if compressed_stats is not None:
-        offset, state2 = compressed_stats
-        absmax = dequantize_blockwise(absmax, state2)
-        absmax += offset
+    if quant_state.nested:
+        absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2)
+        absmax += quant_state.offset
     if absmax.dtype != torch.float32: absmax = absmax.float()
 
     if out is None:
-        out = torch.empty(shape, dtype=dtype, device=A.device)
+        out = torch.empty(quant_state.shape, dtype=quant_state.dtype, device=A.device)
 
     n = out.numel()
 
-
     device = pre_call(A.device)
     is_on_gpu([A, absmax, out])
     if out.dtype == torch.float32:
-        if quant_type == 'fp4':
-            lib.cdequantize_blockwise_fp32_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
+        if quant_state.quant_type == 'fp4':
+            lib.cdequantize_blockwise_fp32_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n))
         else:
-            lib.cdequantize_blockwise_fp32_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
+            lib.cdequantize_blockwise_fp32_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n))
     elif out.dtype == torch.float16:
-        if quant_type == 'fp4':
-            lib.cdequantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
+        if quant_state.quant_type == 'fp4':
+            lib.cdequantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n))
         else:
-            lib.cdequantize_blockwise_fp16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
+            lib.cdequantize_blockwise_fp16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n))
     elif out.dtype == torch.bfloat16:
-        if quant_type == 'fp4':
-            lib.cdequantize_blockwise_bf16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
+        if quant_state.quant_type == 'fp4':
+            lib.cdequantize_blockwise_bf16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n))
         else:
-            lib.cdequantize_blockwise_bf16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
+            lib.cdequantize_blockwise_bf16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n))
     else:
         raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
     post_call(A.device)
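A 4-bit round trip under the new interface might look like this. A sketch, assuming the usual `quantize_4bit(A, absmax=None, out=None, blocksize=64, compress_statistics=False, quant_type='fp4')` signature and a CUDA device:

    import torch
    import bitsandbytes.functional as F

    W = torch.randn(4096, 4096, dtype=torch.float16, device='cuda')

    # quantize_4bit packs two 4-bit values per byte; the QuantState carries
    # absmax, the original shape/dtype, the 4-bit code table, and quant_type.
    W4, state = F.quantize_4bit(W, blocksize=64, compress_statistics=True, quant_type='nf4')

    # dequantize_nf4/dequantize_fp4 just forward to dequantize_4bit.
    W_deq = F.dequantize_nf4(W4, state)
    print((W - W_deq).abs().max())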
@@ -952,22 +969,22 @@ def quantize(A: Tensor, code: Tensor = None, out: Tensor = None) -> Tensor:
 
 def dequantize(
     A: Tensor,
-    quant_state: Tuple[Tensor, Tensor] = None,
+    state: Tuple[Tensor, Tensor] = None,
     absmax: Tensor = None,
     code: Tensor = None,
     out: Tensor = None,
 ) -> Tensor:
-    assert quant_state is not None or absmax is not None
-    if code is None and quant_state is None:
+    assert state is not None or absmax is not None
+    if code is None and state is None:
         if "dynamic" not in name2qmap:
             name2qmap["dynamic"] = create_dynamic_map().to(A.device)
         code = name2qmap["dynamic"]
     code = code.to(A.device)
 
-    if quant_state is None:
-        quant_state = (absmax, code)
-    out = dequantize_no_absmax(A, quant_state[1], out)
-    return out * quant_state[0]
+    if state is None:
+        state = (absmax, code)
+    out = dequantize_no_absmax(A, state[1], out)
+    return out * state[0]
 
 
 def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor:
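Plain (non-blockwise) `dequantize` keeps the simple `(absmax, code)` tuple, and the parameter rename from `quant_state` to `state` keeps it from being mistaken for a `QuantState`. A sketch of that tuple path (illustrative; assumes a CUDA device):

    import torch
    import bitsandbytes.functional as F

    A = torch.randn(256, 256, device='cuda')

    # Whole-tensor 8-bit quantization returns (quantized, (absmax, code)).
    q, (absmax, code) = F.quantize(A)
    A_deq = F.dequantize(q, state=(absmax, code))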
@@ -1472,23 +1489,22 @@ def gemv_4bit(
     out: Tensor = None,
     transposed_A=False,
     transposed_B=False,
-    state=None
+    quant_state=None
 ):
     prev_device = pre_call(A.device)
     #sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype)
-    if state is None:
+    if quant_state is None:
         raise ValueError(f'quant_state cannot be None. gemv_4bit() requires the state from quantize_4bit()')
 
     if A.numel() != A.shape[-1]:
         raise ValueError(f'Dimensions of A are invalid. Must be a vector with the leading dimensions of "1", e.g. [1, 1, 2048]')
 
-    Bshape = state[1]
+    Bshape = quant_state.shape
     bout = Bshape[0]
-    absmax, shape, dtype, blocksize, compressed_stats, quant_type, data_type = state
-    if compressed_stats is not None:
-        offset, state2 = compressed_stats
-        absmax = dequantize_blockwise(absmax, state2)
-        absmax += offset
+    absmax = quant_state.absmax
+    if quant_state.nested:
+        absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2)
+        absmax += quant_state.offset
 
     if out is None:
         if len(A.shape) == 3:
@@ -1502,7 +1518,7 @@ def gemv_4bit(
     lda = Bshape[0]
     ldc = Bshape[0]
     ldb = (A.shape[-1]+1)//2
-    is_on_gpu([B, A, out, absmax, state[-1]])
+    is_on_gpu([B, A, out, absmax, quant_state.code])
     m = ct.c_int32(m)
     n = ct.c_int32(n)
     k = ct.c_int32(k)
@@ -1512,11 +1528,11 @@ def gemv_4bit(
 
     if B.dtype == torch.uint8:
         if A.dtype == torch.float16:
-            lib.cgemm_4bit_inference_naive_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
+            lib.cgemm_4bit_inference_naive_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(quant_state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(quant_state.blocksize))
         elif A.dtype == torch.bfloat16:
-            lib.cgemm_4bit_inference_naive_bf16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
+            lib.cgemm_4bit_inference_naive_bf16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(quant_state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(quant_state.blocksize))
         elif A.dtype == torch.float32:
-            lib.cgemm_4bit_inference_naive_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
+            lib.cgemm_4bit_inference_naive_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(quant_state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(quant_state.blocksize))
         else:
             raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}')
 
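The inference entry point now takes the `QuantState` directly. A sketch of the vector path, assuming weights quantized as in `quantize_4bit` above (`matmul_4bit` forwards `B.t()` to this function the same way):

    import torch
    import bitsandbytes.functional as F

    W = torch.randn(4096, 4096, dtype=torch.float16, device='cuda')
    W4, quant_state = F.quantize_4bit(W, quant_type='nf4')

    # A must be a vector-shaped activation, e.g. [1, 1, hidden].
    x = torch.randn(1, 1, 4096, dtype=torch.float16, device='cuda')
    y = F.gemv_4bit(x, W4.t(), quant_state=quant_state)
    print(y.shape)  # torch.Size([1, 1, 4096])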
@@ -1798,7 +1814,7 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
 
 def mm_dequant(
     A,
-    quant_state,
+    state,
     row_stats,
     col_stats,
     out=None,
@@ -1808,7 +1824,7 @@ def mm_dequant(
 ):
     assert A.dtype == torch.int32
     if bias is not None: assert bias.dtype == torch.float16
-    out_shape = quant_state[0]
+    out_shape = state[0]
     if len(out_shape) == 3:
         out_shape = (out_shape[0] * out_shape[1], out_shape[2])
 
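Here the rename goes the other way: `mm_dequant` receives the `(shape, layout)` tuple produced by `igemmlt`, not a `QuantState`, so calling the parameter `state` avoids the collision. A rough sketch of that int8 call chain, from memory of the surrounding API (layout names such as 'col_turing' vary by GPU generation):

    import torch
    import bitsandbytes.functional as F

    A = torch.randn(32, 64, dtype=torch.float16, device='cuda')
    B = torch.randn(128, 64, dtype=torch.float16, device='cuda')

    # Row-wise int8 quantization with the statistics needed to undo it.
    CA, CAt, SCA, SCAt, _ = F.double_quant(A)
    CB, CBt, SCB, SCBt, _ = F.double_quant(B)

    C32A, SA = F.transform(CA, 'col32')
    CxB, SB = F.transform(CB, 'col_turing')

    out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)

    # Sout32 is the (shape, layout) tuple now named `state` above.
    out = F.mm_dequant(out32, Sout32, SCA, SCB)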