class QuantizedCacheType(Enum):
    """Enumerates the supported quantized cache types."""

    # Every member must have a distinct value: in a Python Enum, a repeated
    # value silently turns the later name into an alias of the earlier one
    # (e.g. a group-wise member set to 1 would *be* AffineAsymmetric).
    AffineSymmetric = 0
    AffineAsymmetric = 1
    AffineSymmetricGroupWise = 2
    AffineAsymmetricGroupWise = 3
2828
2929
3030class QuantizedKVCache (nn .Module ):
@@ -58,8 +58,12 @@ def __init__(
5858 else :
5959 cache_shape = (max_batch_size , max_seq_length , n_heads , head_dim )
6060 scale_shape = (max_batch_size , max_seq_length , n_heads , 1 )
61- self .register_buffer ("k_cache" , torch .zeros (cache_shape , dtype = torch .int8 ))
62- self .register_buffer ("v_cache" , torch .zeros (cache_shape , dtype = torch .int8 ))
61+ self .register_buffer (
62+ "k_cache" , torch .zeros (cache_shape , dtype = self .quantized_cache_dtype )
63+ )
64+ self .register_buffer (
65+ "v_cache" , torch .zeros (cache_shape , dtype = self .quantized_cache_dtype )
66+ )
6367 self .register_buffer (
6468 "k_cache_scales" , torch .ones (scale_shape , dtype = torch .double )
6569 )
@@ -74,43 +78,32 @@ def __init__(
7478 "v_cache_zero_points" , torch .ones (scale_shape , dtype = torch .int64 )
7579 )
7680
77- def update (self , input_pos , k_val , v_val ):
78- # quantize current k_val and store it in the cache
79- k_scales , k_zero_points = (
81+ def _quantize (self , value ):
82+ scales , zero_points = (
8083 torch .ops .quantized_decomposed .choose_qparams_per_token_asymmetric .default (
81- k_val , torch . int8 # no other value is supported by this op anyway
84+ value , self . quantized_cache_dtype
8285 )
8386 )
84- quantized_k_val = torch .ops .quantized_decomposed .quantize_per_token (
85- k_val ,
86- k_scales ,
87- k_zero_points ,
88- torch .iinfo (torch . int8 ).min ,
89- torch .iinfo (torch . int8 ).max ,
90- torch . int8 ,
87+ quantized_value = torch .ops .quantized_decomposed .quantize_per_token (
88+ value ,
89+ scales ,
90+ zero_points ,
91+ torch .iinfo (self . quantized_cache_dtype ).min ,
92+ torch .iinfo (self . quantized_cache_dtype ).max ,
93+ self . quantized_cache_dtype ,
9194 )
95+ return quantized_value , scales , zero_points
9296
93- v_scales , v_zero_points = (
94- torch .ops .quantized_decomposed .choose_qparams_per_token_asymmetric (
95- v_val , torch .int8
96- )
97- )
98- quantized_v_val = torch .ops .quantized_decomposed .quantize_per_token (
99- v_val ,
100- v_scales ,
101- v_zero_points ,
102- torch .iinfo (torch .int8 ).min ,
103- torch .iinfo (torch .int8 ).max ,
104- torch .int8 ,
105- )
97+ def update (self , input_pos , k_val , v_val ):
98+ # quantize current k_val and store it in the cache
99+ quantized_k_val , k_scales , k_zero_points = self ._quantize (k_val )
100+
101+ quantized_v_val , v_scales , v_zero_points = self ._quantize (v_val )
106102
107103 if self .enable_dynamic_shape :
108104 start_pos = input_pos [0 ].item ()
109105 torch ._check_is_size (start_pos )
110- if self .is_transposed :
111- dim_to_slice = 2
112- else :
113- dim_to_slice = 1
106+ dim_to_slice = 2 if self .is_transposed else 1
114107 torch ._check (start_pos < self .k_cache .size (dim_to_slice ))
115108 seq_length = k_val .size (dim_to_slice )
116109 narrowed_k = self .k_cache .narrow (dim_to_slice , start_pos , seq_length )
@@ -154,17 +147,17 @@ def update(self, input_pos, k_val, v_val):
154147 self .k_cache ,
155148 self .k_cache_scales ,
156149 self .k_cache_zero_points ,
157- torch .iinfo (torch . int8 ).min ,
158- torch .iinfo (torch . int8 ).max ,
150+ torch .iinfo (self . quantized_cache_dtype ).min ,
151+ torch .iinfo (self . quantized_cache_dtype ).max ,
159152 self .quantized_cache_dtype ,
160153 self .cache_fp_type ,
161154 )
162155 v_out = torch .ops .quantized_decomposed .dequantize_per_token (
163156 self .v_cache ,
164157 self .v_cache_scales ,
165158 self .v_cache_zero_points ,
166- torch .iinfo (torch . int8 ).min ,
167- torch .iinfo (torch . int8 ).max ,
159+ torch .iinfo (self . quantized_cache_dtype ).min ,
160+ torch .iinfo (self . quantized_cache_dtype ).max ,
168161 self .quantized_cache_dtype ,
169162 self .cache_fp_type ,
170163 )
0 commit comments