@@ -47,6 +47,7 @@ def __init__(
             raise ValueError(
                 f"Only affine symmetric and asymmetric cache types are supported: got {cache_type}"
             )
+
         # For now supporting int8 only
         self.quantized_cache_dtype = torch.int8
         self.cache_fp_type = torch.float32
@@ -65,10 +66,10 @@ def __init__(
6566 "v_cache" , torch .zeros (cache_shape , dtype = self .quantized_cache_dtype )
6667 )
6768 self .register_buffer (
68- "k_cache_scales" , torch .ones (scale_shape , dtype = torch .double )
69+ "k_cache_scales" , torch .ones (scale_shape , dtype = torch .float64 )
6970 )
7071 self .register_buffer (
71- "v_cache_scales" , torch .ones (scale_shape , dtype = torch .double )
72+ "v_cache_scales" , torch .ones (scale_shape , dtype = torch .float64 )
7273 )
7374 if cache_type == QuantizedCacheType .AffineAsymmetric :
7475 self .register_buffer (
@@ -100,47 +101,74 @@ def update(self, input_pos, k_val, v_val):
 
         quantized_v_val, v_scales, v_zero_points = self._quantize(v_val)
 
-        if self.enable_dynamic_shape:
-            start_pos = input_pos[0].item()
-            torch._check_is_size(start_pos)
-            dim_to_slice = 2 if self.is_transposed else 1
-            torch._check(start_pos < self.k_cache.size(dim_to_slice))
-            seq_length = k_val.size(dim_to_slice)
-            narrowed_k = self.k_cache.narrow(dim_to_slice, start_pos, seq_length)
-            narrowed_k_scales = self.k_cache_scales.narrow(
-                dim_to_slice, start_pos, seq_length
-            )
-            narrowed_k_zp = self.k_cache_zero_points.narrow(
-                dim_to_slice, start_pos, seq_length
-            )
-            narrowed_k.copy_(quantized_k_val)
-            narrowed_k_scales.copy_(k_scales)
-            narrowed_k_zp.copy_(k_zero_points)
-            narrowed_v = self.v_cache.narrow(dim_to_slice, start_pos, seq_length)
-            narrowed_v_scales = self.v_cache_scales.narrow(
-                dim_to_slice, start_pos, seq_length
-            )
-            narrowed_v_zp = self.v_cache_zero_points.narrow(
-                dim_to_slice, start_pos, seq_length
-            )
-            narrowed_v.copy_(quantized_v_val)
-            narrowed_v_scales.copy_(v_scales)
-            narrowed_v_zp.copy_(v_zero_points)
-        else:
-            if self.is_transposed:
+        if self.is_transposed:
+            # We cannot use the update_cache op at the moment
+            # if the cache is transposed.
+            # Also note that we should not need separate paths
+            # for dynamic shape vs. not.
+            # The only reason it is done this way is to accommodate
+            # the lowering pains of backends that work better
+            # with the index_put op.
+            if self.enable_dynamic_shape:
+                start_pos = input_pos[0].item()
+                torch._check_is_size(start_pos)
+                dim_to_slice = 2 if self.is_transposed else 1
+                torch._check(start_pos < self.k_cache.size(dim_to_slice))
+                seq_length = k_val.size(dim_to_slice)
+                narrowed_k = self.k_cache.narrow(dim_to_slice, start_pos, seq_length)
+                narrowed_k_scales = self.k_cache_scales.narrow(
+                    dim_to_slice, start_pos, seq_length
+                )
+                narrowed_k_zp = self.k_cache_zero_points.narrow(
+                    dim_to_slice, start_pos, seq_length
+                )
+                narrowed_k.copy_(quantized_k_val)
+                narrowed_k_scales.copy_(k_scales)
+                narrowed_k_zp.copy_(k_zero_points)
+                narrowed_v = self.v_cache.narrow(dim_to_slice, start_pos, seq_length)
+                narrowed_v_scales = self.v_cache_scales.narrow(
+                    dim_to_slice, start_pos, seq_length
+                )
+                narrowed_v_zp = self.v_cache_zero_points.narrow(
+                    dim_to_slice, start_pos, seq_length
+                )
+                narrowed_v.copy_(quantized_v_val)
+                narrowed_v_scales.copy_(v_scales)
+                narrowed_v_zp.copy_(v_zero_points)
+            else:
                 self.k_cache[:, :, input_pos] = quantized_k_val
                 self.k_cache_scales[:, :, input_pos] = k_scales
                 self.k_cache_zero_points[:, :, input_pos] = k_zero_points
                 self.v_cache[:, :, input_pos] = quantized_v_val
                 self.v_cache_scales[:, :, input_pos] = v_scales
                 self.v_cache_zero_points[:, :, input_pos] = v_zero_points
-            else:
-                self.k_cache[:, input_pos] = quantized_k_val
-                self.k_cache_scales[:, input_pos] = k_scales
-                self.k_cache_zero_points[:, input_pos] = k_zero_points
-                self.v_cache[:, input_pos] = quantized_v_val
-                self.v_cache_scales[:, input_pos] = v_scales
-                self.v_cache_zero_points[:, input_pos] = v_zero_points
+        else:
+            # Right now using custom ops on this path.
+            # In the future we can update the custom op to handle
+            # the transposed cache as well.
+            # Note that we may have to revert this change if other ET
+            # backends, such as QNN, want to use the quantized cache with dynamic shape
+            # instead of quantizing on their own.
+            # But until then, opting for code simplicity.
+            start_pos = input_pos[0].item()
+            _ = torch.ops.llama.update_quantized_cache(
+                quantized_k_val, self.k_cache, start_pos
+            )
+            _ = torch.ops.llama.update_quantized_cache(
+                k_scales, self.k_cache_scales, start_pos
+            )
+            _ = torch.ops.llama.update_quantized_cache(
+                k_zero_points, self.k_cache_zero_points, start_pos
+            )
+            _ = torch.ops.llama.update_quantized_cache(
+                quantized_v_val, self.v_cache, start_pos
+            )
+            _ = torch.ops.llama.update_quantized_cache(
+                v_scales, self.v_cache_scales, start_pos
+            )
+            _ = torch.ops.llama.update_quantized_cache(
+                v_zero_points, self.v_cache_zero_points, start_pos
+            )
 
         k_out = torch.ops.quantized_decomposed.dequantize_per_token(
             self.k_cache,
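
To make the numerics concrete, below is a minimal, self-contained sketch of what the update/read round trip does conceptually: quantize the incoming slice per token, write it into the preallocated int8 cache (plus per-token scales and zero points) at start_pos, and dequantize on read. The helpers, shapes, and tolerance are illustrative assumptions, not the actual `_quantize`, `torch.ops.llama.update_quantized_cache`, or `torch.ops.quantized_decomposed.dequantize_per_token` implementations.

```python
import torch


# Illustrative per-token affine asymmetric int8 quantization: one (scale, zero_point)
# pair per slice over the last dimension. Assumed math, standing in for _quantize.
def quantize_per_token(x, qmin=-128, qmax=127):
    x_min = x.amin(dim=-1, keepdim=True)
    x_max = x.amax(dim=-1, keepdim=True)
    scales = (x_max - x_min).clamp(min=1e-8) / (qmax - qmin)
    zero_points = qmin - torch.round(x_min / scales)
    q = torch.clamp(torch.round(x / scales) + zero_points, qmin, qmax).to(torch.int8)
    return q, scales, zero_points


def dequantize_per_token(q, scales, zero_points):
    return (q.to(torch.float32) - zero_points.to(torch.float32)) * scales


# Preallocated caches in the non-transposed layout assumed here:
# [bsz, max_seq_len, n_heads, head_dim], with per-token scales and zero points.
bsz, max_seq_len, n_heads, head_dim = 1, 16, 4, 8
k_cache = torch.zeros(bsz, max_seq_len, n_heads, head_dim, dtype=torch.int8)
k_scales = torch.ones(bsz, max_seq_len, n_heads, 1, dtype=torch.float64)
k_zero_points = torch.zeros(bsz, max_seq_len, n_heads, 1, dtype=torch.int64)

# Update: quantize an incoming slice of seq_length tokens and write it in place
# at start_pos, mirroring the narrow + copy_ pattern in the dynamic-shape branch.
k_val = torch.randn(bsz, 3, n_heads, head_dim)
start_pos, seq_length = 5, k_val.size(1)
q, s, zp = quantize_per_token(k_val)
k_cache.narrow(1, start_pos, seq_length).copy_(q)
k_scales.narrow(1, start_pos, seq_length).copy_(s)
k_zero_points.narrow(1, start_pos, seq_length).copy_(zp)

# Read: dequantize the whole cache back to fp32 before attention.
k_out = dequantize_per_token(k_cache, k_scales, k_zero_points)
written = k_out[:, start_pos : start_pos + seq_length].float()
print(torch.allclose(written, k_val, atol=0.05))  # True up to quantization error
```

Keeping fp64 scales and int64 zero points in the buffers (as the diff does) costs memory but keeps the dequantized values numerically close to what the original fp32 cache would have held.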