@@ -37,8 +37,6 @@ def __init__(
         n_heads,
         head_dim,
         cache_type: QuantizedCacheType = QuantizedCacheType.AffineSymmetric,
-        tranposed=False,
-        enable_dynamic_shape=False,
     ):
         super().__init__()
         if cache_type not in (
@@ -52,14 +50,8 @@ def __init__(
         # For now supporting int8 only
         self.quantized_cache_dtype = torch.int8
         self.cache_fp_type = torch.float32
-        self.is_transposed = tranposed
-        self.enable_dynamic_shape = enable_dynamic_shape
-        if self.is_transposed:
-            cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
-            scale_shape = (max_batch_size, n_heads, max_seq_length, 1)
-        else:
-            cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim)
-            scale_shape = (max_batch_size, max_seq_length, n_heads, 1)
+        cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim)
+        scale_shape = (max_batch_size, max_seq_length, n_heads, 1)
         self.register_buffer(
             "k_cache", torch.zeros(cache_shape, dtype=self.quantized_cache_dtype)
         )
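
The trailing `1` in `scale_shape` reflects per-token quantization: each `(batch, position, head)` slice of length `head_dim` gets its own scale and zero point. A minimal, self-contained sketch of that idea in plain PyTorch (for intuition only; the module's `_quantize` goes through the `quantized_decomposed` ops, not this code):

```python
import torch

# Sketch of per-token symmetric int8 quantization, assumed toy shapes.
# Values are [B, S, H, D]; scales/zero points are [B, S, H, 1], matching scale_shape above.
def quantize_per_token_symmetric(x: torch.Tensor):
    amax = x.abs().amax(dim=-1, keepdim=True)                  # [B, S, H, 1]
    scales = (amax / 127.0).clamp(min=1e-9)                    # one scale per token
    q = torch.clamp(torch.round(x / scales), -128, 127).to(torch.int8)
    zero_points = torch.zeros_like(scales, dtype=torch.int64)  # symmetric: zp = 0
    return q, scales, zero_points

x = torch.randn(1, 16, 8, 64)                                  # [B, S, H, D]
q, s, zp = quantize_per_token_symmetric(x)
assert q.dtype == torch.int8 and s.shape == (1, 16, 8, 1)
```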
@@ -98,71 +90,37 @@ def _quantize(self, value):
         return quantized_value, scales, zero_points

     def update(self, input_pos, k_val, v_val):
93+ """
94+ k_val, v_val: [B, H, S, D]
95+ return: [B, H, S, D]
96+ However the storage is [B, S, H, D] so we incur transpose in, transpose out
97+ This shall be removed by subsequent post-export graph pass
98+ """
99+ k_val = k_val .transpose (1 , 2 )
100+ v_val = v_val .transpose (1 , 2 )
         # quantize current k_val and store it in the cache
         quantized_k_val, k_scales, k_zero_points = self._quantize(k_val)

         quantized_v_val, v_scales, v_zero_points = self._quantize(v_val)

-        if self.is_transposed:
-            # We cannot use update_cache op at the moment
-            # if the cache is transposed
-            # Also note that we shold not need separate paths
-            # for dynamic shape vs !
-            # Only reason it is done this way is to accommodate
-            # for lowering pains of backends that work better
-            # with index_put op.
-            if self.enable_dynamic_shape:
-                start_pos = input_pos[0].item()
-                torch._check_is_size(start_pos)
-                dim_to_slice = 2 if self.is_transposed else 1
-                torch._check(start_pos < self.k_cache.size(dim_to_slice))
-                seq_length = k_val.size(dim_to_slice)
-                narrowed_k = self.k_cache.narrow(dim_to_slice, start_pos, seq_length)
-                narrowed_k_scales = self.k_cache_scales.narrow(
-                    dim_to_slice, start_pos, seq_length
-                )
-                narrowed_k_zp = self.k_cache_zero_points.narrow(
-                    dim_to_slice, start_pos, seq_length
-                )
-                narrowed_k.copy_(quantized_k_val)
-                narrowed_k_scales.copy_(k_scales)
-                narrowed_k_zp.copy_(k_zero_points)
-                narrowed_v = self.v_cache.narrow(dim_to_slice, start_pos, seq_length)
-                narrowed_v_scales = self.v_cache_scales.narrow(
-                    dim_to_slice, start_pos, seq_length
-                )
-                narrowed_v_zp = self.v_cache_zero_points.narrow(
-                    dim_to_slice, start_pos, seq_length
-                )
-                narrowed_v.copy_(quantized_v_val)
-                narrowed_v_scales.copy_(v_scales)
-                narrowed_v_zp.copy_(v_zero_points)
-            else:
-                self.k_cache[:, :, input_pos] = quantized_k_val
-                self.k_cache_scales[:, :, input_pos] = k_scales
-                self.k_cache_zero_points[:, :, input_pos] = k_zero_points
-                self.v_cache[:, :, input_pos] = quantized_v_val
-                self.v_cache_scales[:, :, input_pos] = v_scales
-                self.v_cache_zero_points[:, :, input_pos] = v_zero_points
-        else:
-            # Right now using custom ops on this path.
-            # In future we can update custom op to handle transposed cache
-            # as well.
-            # Note that we may have to revert this change if other ET
-            # backends such as QNN want to use quantized cache, with dynamic shape,
-            # instead of quantizing on their own.
-            # But until this opting for code simplicity
-            start_pos = input_pos[0].item()
-            _ = torch.ops.llama.update_cache(quantized_k_val, self.k_cache, start_pos)
-            _ = torch.ops.llama.update_cache(k_scales, self.k_cache_scales, start_pos)
-            _ = torch.ops.llama.update_cache(
-                k_zero_points, self.k_cache_zero_points, start_pos
-            )
-            _ = torch.ops.llama.update_cache(quantized_v_val, self.v_cache, start_pos)
-            _ = torch.ops.llama.update_cache(v_scales, self.v_cache_scales, start_pos)
-            _ = torch.ops.llama.update_cache(
-                v_zero_points, self.v_cache_zero_points, start_pos
-            )
+        # Right now using custom ops on this path.
+        # In the future we can update the custom op to handle a transposed cache
+        # as well.
+        # Note that we may have to revert this change if other ET
+        # backends such as QNN want to use the quantized cache, with dynamic shape,
+        # instead of quantizing on their own.
+        # But until then, opting for code simplicity.
+        start_pos = input_pos[0].item()
+        _ = torch.ops.llama.update_cache(quantized_k_val, self.k_cache, start_pos)
+        _ = torch.ops.llama.update_cache(k_scales, self.k_cache_scales, start_pos)
+        _ = torch.ops.llama.update_cache(
+            k_zero_points, self.k_cache_zero_points, start_pos
+        )
+        _ = torch.ops.llama.update_cache(quantized_v_val, self.v_cache, start_pos)
+        _ = torch.ops.llama.update_cache(v_scales, self.v_cache_scales, start_pos)
+        _ = torch.ops.llama.update_cache(
+            v_zero_points, self.v_cache_zero_points, start_pos
+        )

         k_out = torch.ops.quantized_decomposed.dequantize_per_token(
             self.k_cache,
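
The docstring above defines the layout contract of `update()`: callers stay in `[B, H, S, D]` while the registered buffers live in `[B, S, H, D]`. A small shape walkthrough with assumed toy sizes, using plain tensor ops as a stand-in for the `update_cache` custom op:

```python
import torch

# Toy sizes, assumed for illustration only.
B, H, S_max, D = 1, 8, 128, 64
k_cache = torch.zeros(B, S_max, H, D)          # storage layout [B, S, H, D]

k_val = torch.randn(B, H, 4, D)                # caller-side keys, [B, H, S, D]
start_pos = 10

k_in = k_val.transpose(1, 2)                   # transpose in -> [B, S, H, D]
# Conceptually what torch.ops.llama.update_cache does: write at start_pos along S.
k_cache[:, start_pos : start_pos + k_in.size(1)] = k_in

k_out = k_cache.transpose(1, 2)                # transpose out -> [B, H, S, D]
assert k_out.shape == (B, H, S_max, D)
```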
@@ -183,42 +141,24 @@ def update(self, input_pos, k_val, v_val):
             self.cache_fp_type,
         )

-        if self.is_transposed:
-            if self.enable_dynamic_shape:
-                start_pos = input_pos[0].item()
-                torch._check_is_size(start_pos)
-                dim_to_slice = 2 if self.is_transposed else 1
-                torch._check(start_pos < self.k_cache.size(dim_to_slice))
-                seq_length = k_val.size(dim_to_slice)
-                narrowed_k = k_out.narrow(dim_to_slice, start_pos, seq_length)
-                narrowed_k.copy_(k_val)
-                narrowed_v = v_out.narrow(dim_to_slice, start_pos, seq_length)
-                narrowed_v.copy_(v_val)
-            else:
-                k_out[:, :, input_pos] = k_val
-                v_out[:, :, input_pos] = v_val
-        else:
-            start_pos = input_pos[0].item()
-            _ = torch.ops.llama.update_cache(k_val, k_out, start_pos)
-            _ = torch.ops.llama.update_cache(v_val, v_out, start_pos)
+        start_pos = input_pos[0].item()
+        _ = torch.ops.llama.update_cache(k_val, k_out, start_pos)
+        _ = torch.ops.llama.update_cache(v_val, v_out, start_pos)

-        return k_out, v_out
+        return k_out.transpose(1, 2), v_out.transpose(1, 2)

     @classmethod
     def from_float(cls, kv_cache, cache_type: QuantizedCacheType):
-        cache_shape = kv_cache.k_cache.shape
-        if kv_cache.is_transposed:
-            max_batch_size, n_heads, max_seq_length, head_dim = cache_shape
-        else:
-            max_batch_size, max_seq_length, n_heads, head_dim = cache_shape
+        max_batch_size, n_heads, max_seq_length, head_dim = kv_cache.k_cache.shape
+        if isinstance(kv_cache, CustomKVCache):
+            # If replacing custom kv cache, then the shape is [B, S, H, D]
+            max_batch_size, max_seq_length, n_heads, head_dim = kv_cache.k_cache.shape
         return cls(
             max_batch_size,
             max_seq_length,
             n_heads,
             head_dim,
             cache_type,
-            kv_cache.is_transposed,
-            kv_cache.enable_dynamic_shape,
         )


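To make the shape dispatch in `from_float` concrete: the eager `KVCache` stores `[B, H, S, D]` while `CustomKVCache` stores `[B, S, H, D]`, and both must resolve to the same constructor arguments. A standalone sketch with stand-in tensors (not the real cache classes):

```python
import torch

eager_k_cache = torch.zeros(1, 8, 128, 64)    # eager KVCache layout: [B, H, S, D]
custom_k_cache = torch.zeros(1, 128, 8, 64)   # CustomKVCache layout: [B, S, H, D]

def resolve_dims(k_cache: torch.Tensor, is_custom_layout: bool):
    # Mirrors the branch in from_float: unpack according to the source layout.
    if is_custom_layout:
        max_batch_size, max_seq_length, n_heads, head_dim = k_cache.shape
    else:
        max_batch_size, n_heads, max_seq_length, head_dim = k_cache.shape
    return max_batch_size, max_seq_length, n_heads, head_dim

# Both layouts resolve to the same (B=1, S=128, H=8, D=64) constructor arguments.
assert resolve_dims(eager_k_cache, False) == resolve_dims(custom_k_cache, True) == (1, 128, 8, 64)
```
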
@@ -254,7 +194,7 @@ def replace_kv_cache_with_quantized_kv_cache(module):
         "Replacing KVCache with QuantizedKVCache. This modifies the model in place."
     )
     for name, child in module.named_children():
-        if isinstance(child, KVCache):
+        if isinstance(child, KVCache) or isinstance(child, CustomKVCache):
             setattr(
                 module,
                 name,
@@ -291,11 +231,13 @@ def __init__(
     def update(
         self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        # input_pos: [S], k_val: [B, S, H, D]
+        # input_pos: [S], k_val: [B, H, S, D]
+        k_val = k_val.transpose(1, 2)
+        v_val = v_val.transpose(1, 2)
         start_pos = input_pos[0].item()
         _ = torch.ops.llama.update_cache(k_val, self.k_cache, start_pos)
         _ = torch.ops.llama.update_cache(v_val, self.v_cache, start_pos)
-        return self.k_cache, self.v_cache
+        return self.k_cache.transpose(1, 2), self.v_cache.transpose(1, 2)


 def replace_kv_cache_with_custom_kv_cache(module):
@@ -313,10 +255,7 @@ def replace_kv_cache_with_custom_kv_cache(module):
         if isinstance(child, KVCache):
             cache_shape = child.k_cache.shape
             cache_dtype = child.k_cache.dtype
-            assert (
-                child.is_transposed is False
-            ), "CustomKVCache does not support transposed cache"
-            max_batch_size, max_seq_length, n_heads, head_dim = cache_shape
+            max_batch_size, n_heads, max_seq_length, head_dim = cache_shape
             setattr(
                 module,
                 name,
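
Both replacement passes follow the usual `named_children` + `setattr` swap pattern, replacing each matching cache submodule in place. A generic, self-contained sketch of that pattern with placeholder classes (not the real `KVCache`/`CustomKVCache`/`QuantizedKVCache`, and not necessarily the exact recursion used here):

```python
import torch.nn as nn

class OldCache(nn.Module):        # stand-in for KVCache / CustomKVCache
    pass

class NewCache(nn.Module):        # stand-in for the replacement cache
    @classmethod
    def from_float(cls, old_cache: "OldCache") -> "NewCache":
        return cls()              # the real from_float copies shapes/dtypes from old_cache

def replace_cache(module: nn.Module) -> nn.Module:
    # Swap every OldCache child for a NewCache built from it; recurse into other children.
    for name, child in module.named_children():
        if isinstance(child, OldCache):
            setattr(module, name, NewCache.from_float(child))
        else:
            replace_cache(child)
    return module

model = nn.Sequential(nn.Linear(4, 4), OldCache())
replace_cache(model)
assert isinstance(model[1], NewCache)
```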