@@ -52,6 +52,8 @@ def __init__(
         self.use_custom_update_cache_op = use_custom_update_cache_op
         self.quantized_cache_dtype = torch.int8
         self.cache_fp_type = torch.float32
+        self.return_float_values = True
+        self.max_context_length = max_context_length
         cache_shape = (max_batch_size, max_context_length, n_heads, head_dim)
         scale_shape = (max_batch_size, max_context_length, n_heads, 1)
         self.register_buffer(
@@ -61,17 +63,17 @@ def __init__(
             "v_cache", torch.zeros(cache_shape, dtype=self.quantized_cache_dtype)
         )
         self.register_buffer(
-            "k_cache_scales", torch.ones(scale_shape, dtype=torch.float64)
+            "k_cache_scales", torch.ones(scale_shape, dtype=torch.float32)
         )
         self.register_buffer(
-            "v_cache_scales", torch.ones(scale_shape, dtype=torch.float64)
+            "v_cache_scales", torch.ones(scale_shape, dtype=torch.float32)
         )
         if cache_type == QuantizedCacheType.AffineAsymmetric:
             self.register_buffer(
-                "k_cache_zero_points", torch.ones(scale_shape, dtype=torch.int64)
+                "k_cache_zero_points", torch.ones(scale_shape, dtype=torch.int8)
             )
             self.register_buffer(
-                "v_cache_zero_points", torch.ones(scale_shape, dtype=torch.int64)
+                "v_cache_zero_points", torch.ones(scale_shape, dtype=torch.int8)
             )

     def _quantize(self, value):
@@ -91,20 +93,15 @@ def _quantize(self, value):
         )
         return quantized_value, scales, zero_points

-    def update(self, input_pos, k_val, v_val):
-        """
-        k_val, v_val: [B, H, S, D]
-        return: [B, H, S, D]
-        However the storage is [B, S, H, D] so we incur transpose in, transpose out
-        This shall be removed by subsequent post-export graph pass
-        """
-        k_val = k_val.transpose(1, 2)
-        v_val = v_val.transpose(1, 2)
-        # quantize current k_val and store it in the cache
+    def _quantize_and_update(self, input_pos, k_val, v_val):
         quantized_k_val, k_scales, k_zero_points = self._quantize(k_val)
-
         quantized_v_val, v_scales, v_zero_points = self._quantize(v_val)

+        k_scales = k_scales.to(torch.float32)
+        k_zero_points = k_zero_points.to(self.quantized_cache_dtype)
+        v_scales = v_scales.to(torch.float32)
+        v_zero_points = v_zero_points.to(self.quantized_cache_dtype)
+
         if self.use_custom_update_cache_op:
             start_pos = input_pos[0].item()
             _ = torch.ops.llama.update_cache(quantized_k_val, self.k_cache, start_pos)
@@ -125,25 +122,30 @@ def update(self, input_pos, k_val, v_val):
             self.v_cache_scales[:, input_pos] = v_scales
             self.v_cache_zero_points[:, input_pos] = v_zero_points

+    def _update_and_return_float_values(self, input_pos, k_val, v_val):
+        self._quantize_and_update(input_pos, k_val, v_val)
+
         k_out = torch.ops.quantized_decomposed.dequantize_per_token(
             self.k_cache,
-            self.k_cache_scales,
-            self.k_cache_zero_points,
+            self.k_cache_scales.to(torch.float64),
+            self.k_cache_zero_points.to(torch.int64),
             torch.iinfo(self.quantized_cache_dtype).min,
             torch.iinfo(self.quantized_cache_dtype).max,
             self.quantized_cache_dtype,
             self.cache_fp_type,
         )
         v_out = torch.ops.quantized_decomposed.dequantize_per_token(
             self.v_cache,
-            self.v_cache_scales,
-            self.v_cache_zero_points,
+            self.v_cache_scales.to(torch.float64),
+            self.v_cache_zero_points.to(torch.int64),
             torch.iinfo(self.quantized_cache_dtype).min,
             torch.iinfo(self.quantized_cache_dtype).max,
             self.quantized_cache_dtype,
             self.cache_fp_type,
         )

+        # When returning float values we just use the last value
+        # instead of the dequantized value.
         start_pos = input_pos[0].item()
         if self.use_custom_update_cache_op:
             _ = torch.ops.llama.update_cache(k_val, k_out, start_pos)
@@ -152,6 +154,29 @@ def update(self, input_pos, k_val, v_val):
             k_out[:, input_pos] = k_val
             v_out[:, input_pos] = v_val

+        return k_out, v_out
+
+    def _update_and_return_quantized_values(self, input_pos, k_val, v_val):
+        self._quantize_and_update(input_pos, k_val, v_val)
+
+        return self.k_cache, self.v_cache
+
+    def update(self, input_pos, k_val, v_val):
+        """
+        k_val, v_val: [B, H, S, D]
+        return: [B, H, S, D]
+        However the storage is [B, S, H, D] so we incur transpose in, transpose out
+        This shall be removed by subsequent post-export graph pass
+        """
+        k_val = k_val.transpose(1, 2)
+        v_val = v_val.transpose(1, 2)
+
+        if self.return_float_values:
+            k_out, v_out = self._update_and_return_float_values(input_pos, k_val, v_val)
+        else:
+            k_out, v_out = self._update_and_return_quantized_values(
+                input_pos, k_val, v_val
+            )
         return k_out.transpose(1, 2), v_out.transpose(1, 2)

     @classmethod
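The substance of this change is the dtype narrowing: the scale buffers drop from float64 to float32 and the zero-point buffers from int64 to int8, shrinking the cache metadata by 4x and 8x respectively, while the .to(torch.float64) / .to(torch.int64) casts around dequantize_per_token restore the wider dtypes that op expects. Below is a minimal plain-PyTorch sketch of why the narrower storage is lossless for int8 affine asymmetric quantization. It is an illustration, not the quantized_decomposed implementation, and it assumes the qparams extend the range to cover zero, which is what pins the zero point inside [-128, 127].

import torch

def quantize_per_token_int8(value: torch.Tensor):
    """Per-token affine asymmetric quantization over the last dim (sketch)."""
    qmin = torch.iinfo(torch.int8).min  # -128
    qmax = torch.iinfo(torch.int8).max  # 127
    # Extend the range to include 0 so the zero point lands in [qmin, qmax].
    min_val = value.amin(dim=-1, keepdim=True).clamp(max=0.0)
    max_val = value.amax(dim=-1, keepdim=True).clamp(min=0.0)
    scales = ((max_val - min_val) / (qmax - qmin)).clamp(min=1e-9)
    zero_points = qmin - torch.round(min_val / scales)
    quantized = torch.clamp(torch.round(value / scales) + zero_points, qmin, qmax)
    # Narrow storage is lossless here: the zero point is provably inside
    # [qmin, qmax], so int8 holds it exactly, and float32 scales lose no
    # precision that matters for fp32 activations.
    return (
        quantized.to(torch.int8),
        scales.to(torch.float32),
        zero_points.to(torch.int8),
    )

def dequantize_per_token_int8(quantized, scales, zero_points):
    # Mirrors the casts in the diff: widen back to the dtypes the
    # decomposed dequantize op expects before the affine transform.
    return (quantized.to(torch.int64) - zero_points.to(torch.int64)) * scales.to(
        torch.float64
    )

x = torch.randn(1, 16, 8, 64)  # [B, S, H, D], matching the cache storage layout
q, s, zp = quantize_per_token_int8(x)
x_hat = dequantize_per_token_int8(q, s, zp).to(torch.float32)
assert (x - x_hat).abs().max() < s.max()  # error bounded by one quantization step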
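The refactor leaves update with the same external contract, callers pass and receive [B, H, S, D], while return_float_values selects between the dequantized fp32 path and the raw int8 buffers. A hypothetical call site follows; neither the class name nor the constructor's exact signature appears in these hunks, so both are assumptions inferred from cache_shape and scale_shape in __init__.

import torch

# Hypothetical instantiation; argument names are inferred, not confirmed.
cache = QuantizedKVCache(
    max_batch_size=1,
    max_context_length=128,
    n_heads=8,
    head_dim=64,
    cache_type=QuantizedCacheType.AffineAsymmetric,
)

B, H, S, D = 1, 8, 16, 64
k_val = torch.randn(B, H, S, D)
v_val = torch.randn(B, H, S, D)
input_pos = torch.arange(S)  # positions written this step

# Default path (return_float_values=True, set in __init__): the whole cache
# is dequantized, the current positions are overwritten with the exact float
# inputs, and fp32 tensors of shape [B, H, max_context_length, D] come back.
k_out, v_out = cache.update(input_pos, k_val, v_val)

# Quantized path: the raw int8 cache buffers are returned (transposed to
# [B, H, max_context_length, D]); scales and zero points stay behind in the
# module's registered buffers.
cache.return_float_values = False
k_q, v_q = cache.update(input_pos, k_val, v_val)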