@@ -52,6 +52,8 @@ def __init__(
         self.use_custom_update_cache_op = use_custom_update_cache_op
         self.quantized_cache_dtype = torch.int8
         self.cache_fp_type = torch.float32
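+        # When True, update() returns dequantized float K/V values; when False,
+        # it returns the raw int8 caches.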
+        self.return_float_values = True
+        self.max_context_length = max_context_length
         cache_shape = (max_batch_size, max_context_length, n_heads, head_dim)
         scale_shape = (max_batch_size, max_context_length, n_heads, 1)
         self.register_buffer(
@@ -61,17 +63,17 @@ def __init__(
             "v_cache", torch.zeros(cache_shape, dtype=self.quantized_cache_dtype)
         )
         self.register_buffer(
-            "k_cache_scales", torch.ones(scale_shape, dtype=torch.float64)
+            "k_cache_scales", torch.ones(scale_shape, dtype=torch.float32)
         )
         self.register_buffer(
-            "v_cache_scales", torch.ones(scale_shape, dtype=torch.float64)
+            "v_cache_scales", torch.ones(scale_shape, dtype=torch.float32)
         )
         if cache_type == QuantizedCacheType.AffineAsymmetric:
             self.register_buffer(
-                "k_cache_zero_points", torch.ones(scale_shape, dtype=torch.int64)
+                "k_cache_zero_points", torch.ones(scale_shape, dtype=torch.int8)
             )
             self.register_buffer(
-                "v_cache_zero_points", torch.ones(scale_shape, dtype=torch.int64)
+                "v_cache_zero_points", torch.ones(scale_shape, dtype=torch.int8)
             )

     def _quantize(self, value):
@@ -91,20 +93,15 @@ def _quantize(self, value):
         )
         return quantized_value, scales, zero_points

-    def update(self, input_pos, k_val, v_val):
-        """
-        k_val, v_val: [B, H, S, D]
-        return: [B, H, S, D]
-        However the storage is [B, S, H, D] so we incur transpose in, transpose out
-        This shall be removed by subsequent post-export graph pass
-        """
-        k_val = k_val.transpose(1, 2)
-        v_val = v_val.transpose(1, 2)
-        # quantize current k_val and store it in the cache
+    def _quantize_and_update(self, input_pos, k_val, v_val):
         quantized_k_val, k_scales, k_zero_points = self._quantize(k_val)
-
         quantized_v_val, v_scales, v_zero_points = self._quantize(v_val)

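+        # Cast to the dtypes of the registered scale / zero-point buffers
+        # (float32 scales, int8 zero points) before storing.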
+        k_scales = k_scales.to(torch.float32)
+        k_zero_points = k_zero_points.to(self.quantized_cache_dtype)
+        v_scales = v_scales.to(torch.float32)
+        v_zero_points = v_zero_points.to(self.quantized_cache_dtype)
+
         if self.use_custom_update_cache_op:
             start_pos = input_pos[0].item()
             _ = torch.ops.llama.update_cache(quantized_k_val, self.k_cache, start_pos)
@@ -125,25 +122,30 @@ def update(self, input_pos, k_val, v_val):
             self.v_cache_scales[:, input_pos] = v_scales
             self.v_cache_zero_points[:, input_pos] = v_zero_points

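+    # Quantize and store the new K/V, then return the whole cache dequantized
+    # to float, with the original float values written back at input_pos.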
+    def _update_and_return_float_values(self, input_pos, k_val, v_val):
+        self._quantize_and_update(input_pos, k_val, v_val)
+
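+        # The buffers now store float32 scales and int8 zero points; they are
+        # cast back to float64 / int64 at the dequantize calls below.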
         k_out = torch.ops.quantized_decomposed.dequantize_per_token(
             self.k_cache,
-            self.k_cache_scales,
-            self.k_cache_zero_points,
+            self.k_cache_scales.to(torch.float64),
+            self.k_cache_zero_points.to(torch.int64),
             torch.iinfo(self.quantized_cache_dtype).min,
             torch.iinfo(self.quantized_cache_dtype).max,
             self.quantized_cache_dtype,
             self.cache_fp_type,
         )
         v_out = torch.ops.quantized_decomposed.dequantize_per_token(
             self.v_cache,
-            self.v_cache_scales,
-            self.v_cache_zero_points,
+            self.v_cache_scales.to(torch.float64),
+            self.v_cache_zero_points.to(torch.int64),
             torch.iinfo(self.quantized_cache_dtype).min,
             torch.iinfo(self.quantized_cache_dtype).max,
             self.quantized_cache_dtype,
             self.cache_fp_type,
         )

+        # When returning float values, write the original float values back at
+        # the current positions instead of their dequantized counterparts.
         start_pos = input_pos[0].item()
         if self.use_custom_update_cache_op:
             _ = torch.ops.llama.update_cache(k_val, k_out, start_pos)
@@ -152,6 +154,29 @@ def update(self, input_pos, k_val, v_val):
             k_out[:, input_pos] = k_val
             v_out[:, input_pos] = v_val

+        return k_out, v_out
+
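+    # Quantize and store the new K/V, then return the raw int8 caches without
+    # dequantizing them.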
+    def _update_and_return_quantized_values(self, input_pos, k_val, v_val):
+        self._quantize_and_update(input_pos, k_val, v_val)
+
+        return self.k_cache, self.v_cache
+
+    def update(self, input_pos, k_val, v_val):
+        """
+        k_val, v_val: [B, H, S, D]
+        return: [B, H, S, D]
+        However the storage is [B, S, H, D] so we incur transpose in, transpose out
+        This shall be removed by subsequent post-export graph pass
+        """
+        k_val = k_val.transpose(1, 2)
+        v_val = v_val.transpose(1, 2)
+
+        if self.return_float_values:
+            k_out, v_out = self._update_and_return_float_values(input_pos, k_val, v_val)
+        else:
+            k_out, v_out = self._update_and_return_quantized_values(
+                input_pos, k_val, v_val
+            )
         return k_out.transpose(1, 2), v_out.transpose(1, 2)

     @classmethod