import logging
from enum import Enum
from typing import Optional, Tuple

import torch
import torch.nn as nn
@@ -93,7 +93,7 @@ def _quantize(self, value):
9393        )
9494        return  quantized_value , scales , zero_points 
9595
96-     def  _quantize_and_update (self , input_pos , k_val , v_val ):
96+     def  _quantize_and_update (self , input_pos , k_val , v_val ,  indices = None ):
9797        quantized_k_val , k_scales , k_zero_points  =  self ._quantize (k_val )
9898        quantized_v_val , v_scales , v_zero_points  =  self ._quantize (v_val )
9999
@@ -104,26 +104,57 @@ def _quantize_and_update(self, input_pos, k_val, v_val):
104104
105105        if  self .use_custom_update_cache_op :
106106            start_pos  =  input_pos [0 ].item ()
107-             _  =  torch .ops .llama .update_cache (quantized_k_val , self .k_cache , start_pos )
108-             _  =  torch .ops .llama .update_cache (k_scales , self .k_cache_scales , start_pos )
109-             _  =  torch .ops .llama .update_cache (
110-                 k_zero_points , self .k_cache_zero_points , start_pos 
111-             )
112-             _  =  torch .ops .llama .update_cache (quantized_v_val , self .v_cache , start_pos )
113-             _  =  torch .ops .llama .update_cache (v_scales , self .v_cache_scales , start_pos )
114-             _  =  torch .ops .llama .update_cache (
115-                 v_zero_points , self .v_cache_zero_points , start_pos 
116-             )
107+             if  indices  is  not None :
108+                 _  =  torch .ops .llama .update_cache_with_indices (
109+                     quantized_k_val , self .k_cache , start_pos , indices 
110+                 )
111+                 _  =  torch .ops .llama .update_cache_with_indices (
112+                     k_scales , self .k_cache_scales , start_pos , indices 
113+                 )
114+                 _  =  torch .ops .llama .update_cache_with_indices (
115+                     k_zero_points , self .k_cache_zero_points , start_pos , indices 
116+                 )
117+                 _  =  torch .ops .llama .update_cache_with_indices (
118+                     quantized_v_val , self .v_cache , start_pos , indices 
119+                 )
120+                 _  =  torch .ops .llama .update_cache_with_indices (
121+                     v_scales , self .v_cache_scales , start_pos , indices 
122+                 )
123+                 _  =  torch .ops .llama .update_cache_with_indices (
124+                     v_zero_points , self .v_cache_zero_points , start_pos , indices 
125+                 )
126+             else :
127+                 _  =  torch .ops .llama .update_cache (
128+                     quantized_k_val , self .k_cache , start_pos 
129+                 )
130+                 _  =  torch .ops .llama .update_cache (
131+                     k_scales , self .k_cache_scales , start_pos 
132+                 )
133+                 _  =  torch .ops .llama .update_cache (
134+                     k_zero_points , self .k_cache_zero_points , start_pos 
135+                 )
136+                 _  =  torch .ops .llama .update_cache (
137+                     quantized_v_val , self .v_cache , start_pos 
138+                 )
139+                 _  =  torch .ops .llama .update_cache (
140+                     v_scales , self .v_cache_scales , start_pos 
141+                 )
142+                 _  =  torch .ops .llama .update_cache (
143+                     v_zero_points , self .v_cache_zero_points , start_pos 
144+                 )
117145        else :
146+             assert  indices  is  None , "Indices not supported for this path" 
147+             # Following is also broken because in prefill input_pos = [0] 
148+             # but we need to update some slice of cache 
118149            self .k_cache [:, input_pos ] =  quantized_k_val 
119150            self .k_cache_scales [:, input_pos ] =  k_scales 
120151            self .k_cache_zero_points [:, input_pos ] =  k_zero_points 
121152            self .v_cache [:, input_pos ] =  quantized_v_val 
122153            self .v_cache_scales [:, input_pos ] =  v_scales 
123154            self .v_cache_zero_points [:, input_pos ] =  v_zero_points 
124155
125-     def  _update_and_return_float_values (self , input_pos , k_val , v_val ):
126-         self ._quantize_and_update (input_pos , k_val , v_val )
156+     def  _update_and_return_float_values (self , input_pos , k_val , v_val ,  indices = None ):
157+         self ._quantize_and_update (input_pos , k_val , v_val ,  indices )
127158
128159        k_out  =  torch .ops .quantized_decomposed .dequantize_per_token (
129160            self .k_cache ,
@@ -144,24 +175,34 @@ def _update_and_return_float_values(self, input_pos, k_val, v_val):
144175            self .cache_fp_type ,
145176        )
146177
147-         # When returning float values we jsut  use the last value 
178+         # When returning float values we just  use the last value 
148179        # instead of dequantized value. 
149180        start_pos  =  input_pos [0 ].item ()
150181        if  self .use_custom_update_cache_op :
151-             _  =  torch .ops .llama .update_cache (k_val , k_out , start_pos )
152-             _  =  torch .ops .llama .update_cache (v_val , v_out , start_pos )
182+             if  indices  is  not None :
183+                 _  =  torch .ops .llama .update_cache_with_indices (
184+                     k_val , k_out , start_pos , indices 
185+                 )
186+                 _  =  torch .ops .llama .update_cache_with_indices (
187+                     v_val , v_out , start_pos , indices 
188+                 )
189+             else :
190+                 _  =  torch .ops .llama .update_cache (k_val , k_out , start_pos )
191+                 _  =  torch .ops .llama .update_cache (v_val , v_out , start_pos )
153192        else :
154193            k_out [:, input_pos ] =  k_val 
155194            v_out [:, input_pos ] =  v_val 
156195
157196        return  k_out , v_out 
158197
159-     def  _update_and_return_quantized_values (self , input_pos , k_val , v_val ):
160-         self ._quantize_and_update (input_pos , k_val , v_val )
198+     def  _update_and_return_quantized_values (
199+         self , input_pos , k_val , v_val , indices = None 
200+     ):
201+         self ._quantize_and_update (input_pos , k_val , v_val , indices )
161202
162203        return  self .k_cache , self .v_cache 
163204
164-     def  update (self , input_pos , k_val , v_val ):
205+     def  update (self , input_pos , k_val , v_val ,  indices = None ):
165206        """ 
166207        k_val, v_val: [B, H, S, D] 
167208        return: [B, H, S, D] 
@@ -172,10 +213,12 @@ def update(self, input_pos, k_val, v_val):
172213        v_val  =  v_val .transpose (1 , 2 )
173214
174215        if  self .return_float_values :
175-             k_out , v_out  =  self ._update_and_return_float_values (input_pos , k_val , v_val )
216+             k_out , v_out  =  self ._update_and_return_float_values (
217+                 input_pos , k_val , v_val , indices 
218+             )
176219        else :
177220            k_out , v_out  =  self ._update_and_return_quantized_values (
178-                 input_pos , k_val , v_val 
221+                 input_pos , k_val , v_val ,  indices 
179222            )
180223        return  k_out .transpose (1 , 2 ), v_out .transpose (1 , 2 )
181224
@@ -277,14 +320,28 @@ def __init__(
277320        )
278321
279322    def  update (
280-         self , input_pos : torch .Tensor , k_val : torch .Tensor , v_val : torch .Tensor 
323+         self ,
324+         input_pos : torch .Tensor ,
325+         k_val : torch .Tensor ,
326+         v_val : torch .Tensor ,
327+         indices : Optional [torch .Tensor ] =  None ,
281328    ) ->  Tuple [torch .Tensor , torch .Tensor ]:
282329        # input_pos: [S], k_val: [B, H, S, D] 
283330        k_val  =  k_val .transpose (1 , 2 )
284331        v_val  =  v_val .transpose (1 , 2 )
285332        start_pos  =  input_pos [0 ].item ()
286-         _  =  torch .ops .llama .update_cache (k_val , self .k_cache , start_pos )
287-         _  =  torch .ops .llama .update_cache (v_val , self .v_cache , start_pos )
333+ 
334+         if  indices  is  not None :
335+             _  =  torch .ops .llama .update_cache_with_indices (
336+                 k_val , self .k_cache , start_pos , indices 
337+             )
338+             _  =  torch .ops .llama .update_cache_with_indices (
339+                 v_val , self .v_cache , start_pos , indices 
340+             )
341+         else :
342+             _  =  torch .ops .llama .update_cache (k_val , self .k_cache , start_pos )
343+             _  =  torch .ops .llama .update_cache (v_val , self .v_cache , start_pos )
344+ 
288345        return  (
289346            self .k_cache .transpose (1 , 2 ),
290347            self .v_cache .transpose (1 , 2 ),
0 commit comments