    groupwise_affine_quantize_tensor_from_qparams,
    groupwise_affine_dequantize_tensor_from_qparams,
    pack_tinygemm_scales_and_zeros,
+    groupwise_affine_quantize_tensor,
)
aten = torch.ops.aten

__all__ = [
    "MultiInput",
-    "WeightOnlyInt4Linear",
    "Int4WeightOnlyGPTQQuantizer",
+    "Int4WeightOnlyQuantizer",
] + add_ons

if lm_eval_available:
@@ -117,7 +118,10 @@ def __init__(

    @property
    def eot_token_id(self):
-        return self._tokenizer.eos_id()
+        try:
+            return self._tokenizer.eos_id()
+        except:
+            return self._tokenizer.eos_id

    @property
    def max_length(self):
@@ -139,7 +143,10 @@ def tok_encode(self, string: str, **kwargs):
        # TODO: verify this for multi-batch as well
        tokens = self._tokenizer.encode(string)
        if hasattr(self._tokenizer, "bos_id"):
-            tokens = [self._tokenizer.bos_id()] + tokens
+            try:
+                tokens = [self._tokenizer.bos_id()] + tokens
+            except:
+                tokens = [self._tokenizer.bos_id] + tokens
        return tokens

    def tok_decode(self, tokens):
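
Note: both tokenizer hunks above apply the same compatibility shim. Some tokenizer wrappers expose `eos_id`/`bos_id` as methods (sentencepiece-style) while others expose them as plain integer attributes; the try/except attempts the method call first and falls back to reading the attribute. A minimal sketch of the same idea, using a hypothetical `_id_of` helper that is not part of the diff:

```python
def _id_of(value):
    # Call eos_id/bos_id if it is a method; otherwise assume it is already the int id.
    return value() if callable(value) else value

# e.g. eot_token = _id_of(tokenizer.eos_id)  # works for either tokenizer style
```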
@@ -747,6 +754,12 @@ def _convert_for_runtime(self, model: torch.nn.Module) -> "nn.Module":
    def quantize(self, model: torch.nn.Module, inputs: List[MultiInput], **kwargs: Any) -> torch.nn.Module:
        pass

+def _check_linear_int4_k(k, groupsize=1, inner_k_tiles=None):
+    k_divisible_by_groupsize = k % groupsize == 0
+    if inner_k_tiles is not None:
+        k_divisible_by_16_times_inner_k_tiles = k % (inner_k_tiles * 16) == 0
+        return k_divisible_by_groupsize and k_divisible_by_16_times_inner_k_tiles
+    return k_divisible_by_groupsize

def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
    origin_x_size = x.size()
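
Note: the relocated `_check_linear_int4_k` helper is a pure divisibility check: `groupsize` must divide `k`, and when `inner_k_tiles` is given, `inner_k_tiles * 16` must divide `k` as well. A few illustrative calls (values chosen for illustration, not taken from the diff):

```python
_check_linear_int4_k(4096, groupsize=128, inner_k_tiles=8)  # True:  4096 % 128 == 0 and 4096 % (8 * 16) == 0
_check_linear_int4_k(4000, groupsize=128, inner_k_tiles=8)  # False: 4000 % 128 != 0
_check_linear_int4_k(4000, groupsize=8)                     # True:  only 4000 % 8 == 0 is checked
```

This is also why the `WeightOnlyInt4Linear.__init__` hunk below flips the condition to `not _check_linear_int4_k(...)`: padding is needed precisely when the check fails.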
@@ -767,7 +780,7 @@ def __init__(
        bias=False, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8, use_cuda=True,
    ) -> None:
        super().__init__()
-        self.padding = _check_linear_int4_k(in_features, groupsize, inner_k_tiles)
+        self.padding = not _check_linear_int4_k(in_features, groupsize, inner_k_tiles)
        if self.padding:
            from model import find_multiple
            self.origin_in_features = in_features
@@ -806,14 +819,6 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
            self.weight, self.scales_and_zeros, self.out_features, self.groupsize
        )

-
-def _check_linear_int4_k(k, groupsize=1, inner_k_tiles=None):
-    k_divisible_by_groupsize = k % groupsize == 0
-    if inner_k_tiles is not None:
-        k_divisible_by_16_times_inner_k_tiles = k % (inner_k_tiles * 16) == 0
-        return k_divisible_by_groupsize and k_divisible_by_16_times_inner_k_tiles
-    return k_divisible_by_groupsize
-
def replace_linear_int4(module, groupsize, inner_k_tiles, padding_allowed, use_cuda=True, skip_layer_func=None):

    for name, child in module.named_children():
@@ -826,6 +831,83 @@ def replace_linear_int4(module, groupsize, inner_k_tiles, padding_allowed, use_c
        else:
            replace_linear_int4(child, groupsize, inner_k_tiles, padding_allowed, use_cuda, skip_layer_func)

+class Int4WeightOnlyQuantizer(Quantizer):
+    def __init__(
+        self,
+        groupsize: int = 256,
+        padding_allowed: bool = True,
+        inner_k_tiles: Optional[int] = 8,
+    ) -> None:
+        super().__init__()
+        assert inner_k_tiles in [2, 4, 8]
+        assert groupsize in [32, 64, 128, 256]
+
+        self.inner_k_tiles = inner_k_tiles
+        self.groupsize: int = groupsize
+        self.padding_allowed: bool = padding_allowed
+
+    @torch.no_grad()
+    def _create_quantized_state_dict(
+        self, model: torch.nn.Module
+    ) -> Dict[str, torch.Tensor]:
+        cur_state_dict = model.state_dict()
+        for fqn, mod in model.named_modules():
+            if isinstance(mod, torch.nn.Linear):
+                assert not mod.bias
+                out_features = mod.out_features
+                in_features = mod.in_features
+                # assert out_features % 8 == 0, "require out_features % 8 == 0"
+                print(f"linear: {fqn}, in={in_features}, out={out_features}")
+
+                assert (
+                    in_features % self.groupsize == 0
+                ), f"require in_features:{in_features} % self.groupsize:{self.groupsize} == 0"
+
+                weight = mod.weight.data
+                if not _check_linear_int4_k(
+                    in_features, self.groupsize, self.inner_k_tiles
+                ):
+                    if self.padding_allowed:
+                        from .utils import find_multiple
+                        import torch.nn.functional as F
+                        print(f"warning: {fqn} is padded to satisfy in_features % 1024 == 0")
+                        padded_in_features = find_multiple(in_features, 1024)
+                        weight = F.pad(weight, pad=(0, padded_in_features - in_features))
+                    else:
+                        print(f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, " +
+                              "and that groupsize and inner_k_tiles*16 evenly divide into it")
+                        continue
+                (
+                    w_int4x8,
+                    scales_and_zeros
+                ) = groupwise_affine_quantize_tensor(
+                    weight,
+                    4,  # n_bit
+                    self.groupsize,
+                )
+                weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(w_int4x8.to("cuda"), self.inner_k_tiles)
+                cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to("cuda")
+                cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to("cuda")
+        return cur_state_dict
+
+    def _convert_for_runtime(self, model: torch.nn.Module) -> torch.nn.Module:
+        replace_linear_int4(
+            model,
+            self.groupsize,
+            self.inner_k_tiles,
+            self.padding_allowed,
+        )
+        return model
+
+    def quantize(
+        self, model: torch.nn.Module, *args: Any, **kwargs: Any
+    ) -> torch.nn.Module:
+        state_dict = self._create_quantized_state_dict(model)
+        model = self._convert_for_runtime(model)
+        # TODO: make it strict
+        model.load_state_dict(state_dict, strict=False)
+        return model
+
class Int4WeightOnlyGPTQQuantizer(GPTQQuantizer):
    def __init__(
        self,
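
Note: the new `Int4WeightOnlyQuantizer` implements the same `Quantizer` interface as the GPTQ variant but needs no calibration inputs. A minimal usage sketch, assuming a bias-free model whose linear layers satisfy the groupsize/inner_k_tiles constraints (or `padding_allowed=True`); the model object itself is an assumption, not part of the diff:

```python
quantizer = Int4WeightOnlyQuantizer(groupsize=128, inner_k_tiles=8)
model = quantizer.quantize(model)  # packs int4 weights and swaps nn.Linear -> WeightOnlyInt4Linear
# The packed weights are created on CUDA in _create_quantized_state_dict,
# so inference is expected to run on a CUDA device.
```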