66
77import logging
88from enum import Enum
9- from typing import Tuple
9+ from typing import Optional , Tuple
1010
1111import torch
1212import torch .nn as nn
@@ -93,7 +93,7 @@ def _quantize(self, value):
9393 )
9494 return quantized_value , scales , zero_points
9595
96- def _quantize_and_update (self , input_pos , k_val , v_val ):
96+ def _quantize_and_update (self , input_pos , k_val , v_val , indices = None ):
9797 quantized_k_val , k_scales , k_zero_points = self ._quantize (k_val )
9898 quantized_v_val , v_scales , v_zero_points = self ._quantize (v_val )
9999
@@ -104,26 +104,57 @@ def _quantize_and_update(self, input_pos, k_val, v_val):
104104
105105 if self .use_custom_update_cache_op :
106106 start_pos = input_pos [0 ].item ()
107- _ = torch .ops .llama .update_cache (quantized_k_val , self .k_cache , start_pos )
108- _ = torch .ops .llama .update_cache (k_scales , self .k_cache_scales , start_pos )
109- _ = torch .ops .llama .update_cache (
110- k_zero_points , self .k_cache_zero_points , start_pos
111- )
112- _ = torch .ops .llama .update_cache (quantized_v_val , self .v_cache , start_pos )
113- _ = torch .ops .llama .update_cache (v_scales , self .v_cache_scales , start_pos )
114- _ = torch .ops .llama .update_cache (
115- v_zero_points , self .v_cache_zero_points , start_pos
116- )
107+ if indices is not None :
108+ _ = torch .ops .llama .update_cache_with_indices (
109+ quantized_k_val , self .k_cache , start_pos , indices
110+ )
111+ _ = torch .ops .llama .update_cache_with_indices (
112+ k_scales , self .k_cache_scales , start_pos , indices
113+ )
114+ _ = torch .ops .llama .update_cache_with_indices (
115+ k_zero_points , self .k_cache_zero_points , start_pos , indices
116+ )
117+ _ = torch .ops .llama .update_cache_with_indices (
118+ quantized_v_val , self .v_cache , start_pos , indices
119+ )
120+ _ = torch .ops .llama .update_cache_with_indices (
121+ v_scales , self .v_cache_scales , start_pos , indices
122+ )
123+ _ = torch .ops .llama .update_cache_with_indices (
124+ v_zero_points , self .v_cache_zero_points , start_pos , indices
125+ )
126+ else :
127+ _ = torch .ops .llama .update_cache (
128+ quantized_k_val , self .k_cache , start_pos
129+ )
130+ _ = torch .ops .llama .update_cache (
131+ k_scales , self .k_cache_scales , start_pos
132+ )
133+ _ = torch .ops .llama .update_cache (
134+ k_zero_points , self .k_cache_zero_points , start_pos
135+ )
136+ _ = torch .ops .llama .update_cache (
137+ quantized_v_val , self .v_cache , start_pos
138+ )
139+ _ = torch .ops .llama .update_cache (
140+ v_scales , self .v_cache_scales , start_pos
141+ )
142+ _ = torch .ops .llama .update_cache (
143+ v_zero_points , self .v_cache_zero_points , start_pos
144+ )
117145 else :
146+ assert indices is None , "Indices not supported for this path"
147+ # Following is also broken because in prefill input_pos = [0]
148+ # but we need to update some slice of cache
118149 self .k_cache [:, input_pos ] = quantized_k_val
119150 self .k_cache_scales [:, input_pos ] = k_scales
120151 self .k_cache_zero_points [:, input_pos ] = k_zero_points
121152 self .v_cache [:, input_pos ] = quantized_v_val
122153 self .v_cache_scales [:, input_pos ] = v_scales
123154 self .v_cache_zero_points [:, input_pos ] = v_zero_points
124155
125- def _update_and_return_float_values (self , input_pos , k_val , v_val ):
126- self ._quantize_and_update (input_pos , k_val , v_val )
156+ def _update_and_return_float_values (self , input_pos , k_val , v_val , indices = None ):
157+ self ._quantize_and_update (input_pos , k_val , v_val , indices )
127158
128159 k_out = torch .ops .quantized_decomposed .dequantize_per_token (
129160 self .k_cache ,
@@ -144,24 +175,34 @@ def _update_and_return_float_values(self, input_pos, k_val, v_val):
144175 self .cache_fp_type ,
145176 )
146177
147- # When returning float values we jsut use the last value
178+ # When returning float values we just use the last value
148179 # instead of dequantized value.
149180 start_pos = input_pos [0 ].item ()
150181 if self .use_custom_update_cache_op :
151- _ = torch .ops .llama .update_cache (k_val , k_out , start_pos )
152- _ = torch .ops .llama .update_cache (v_val , v_out , start_pos )
182+ if indices is not None :
183+ _ = torch .ops .llama .update_cache_with_indices (
184+ k_val , k_out , start_pos , indices
185+ )
186+ _ = torch .ops .llama .update_cache_with_indices (
187+ v_val , v_out , start_pos , indices
188+ )
189+ else :
190+ _ = torch .ops .llama .update_cache (k_val , k_out , start_pos )
191+ _ = torch .ops .llama .update_cache (v_val , v_out , start_pos )
153192 else :
154193 k_out [:, input_pos ] = k_val
155194 v_out [:, input_pos ] = v_val
156195
157196 return k_out , v_out
158197
159- def _update_and_return_quantized_values (self , input_pos , k_val , v_val ):
160- self ._quantize_and_update (input_pos , k_val , v_val )
198+ def _update_and_return_quantized_values (
199+ self , input_pos , k_val , v_val , indices = None
200+ ):
201+ self ._quantize_and_update (input_pos , k_val , v_val , indices )
161202
162203 return self .k_cache , self .v_cache
163204
164- def update (self , input_pos , k_val , v_val ):
205+ def update (self , input_pos , k_val , v_val , indices = None ):
165206 """
166207 k_val, v_val: [B, H, S, D]
167208 return: [B, H, S, D]
@@ -172,10 +213,12 @@ def update(self, input_pos, k_val, v_val):
172213 v_val = v_val .transpose (1 , 2 )
173214
174215 if self .return_float_values :
175- k_out , v_out = self ._update_and_return_float_values (input_pos , k_val , v_val )
216+ k_out , v_out = self ._update_and_return_float_values (
217+ input_pos , k_val , v_val , indices
218+ )
176219 else :
177220 k_out , v_out = self ._update_and_return_quantized_values (
178- input_pos , k_val , v_val
221+ input_pos , k_val , v_val , indices
179222 )
180223 return k_out .transpose (1 , 2 ), v_out .transpose (1 , 2 )
181224
@@ -277,14 +320,28 @@ def __init__(
277320 )
278321
279322 def update (
280- self , input_pos : torch .Tensor , k_val : torch .Tensor , v_val : torch .Tensor
323+ self ,
324+ input_pos : torch .Tensor ,
325+ k_val : torch .Tensor ,
326+ v_val : torch .Tensor ,
327+ indices : Optional [torch .Tensor ] = None ,
281328 ) -> Tuple [torch .Tensor , torch .Tensor ]:
282329 # input_pos: [S], k_val: [B, H, S, D]
283330 k_val = k_val .transpose (1 , 2 )
284331 v_val = v_val .transpose (1 , 2 )
285332 start_pos = input_pos [0 ].item ()
286- _ = torch .ops .llama .update_cache (k_val , self .k_cache , start_pos )
287- _ = torch .ops .llama .update_cache (v_val , self .v_cache , start_pos )
333+
334+ if indices is not None :
335+ _ = torch .ops .llama .update_cache_with_indices (
336+ k_val , self .k_cache , start_pos , indices
337+ )
338+ _ = torch .ops .llama .update_cache_with_indices (
339+ v_val , self .v_cache , start_pos , indices
340+ )
341+ else :
342+ _ = torch .ops .llama .update_cache (k_val , self .k_cache , start_pos )
343+ _ = torch .ops .llama .update_cache (v_val , self .v_cache , start_pos )
344+
288345 return (
289346 self .k_cache .transpose (1 , 2 ),
290347 self .v_cache .transpose (1 , 2 ),
0 commit comments