@@ -47,6 +47,7 @@ def __init__(
             raise ValueError(
                 f"Only affine symmetric and asymmetric cache types are supported: got {cache_type}"
             )
+
         # For now supporting int8 only
         self.quantized_cache_dtype = torch.int8
         self.cache_fp_type = torch.float32
@@ -104,51 +105,78 @@ def update(self, input_pos, k_val, v_val):
             torch.int8,
         )
 
-        if self.enable_dynamic_shape:
-            start_pos = input_pos[0].item()
-            torch._check_is_size(start_pos)
-            if self.is_transposed:
-                dim_to_slice = 2
+        if self.is_transposed:
+            # We cannot use the update_cache op at the moment
+            # if the cache is transposed.
+            # Also note that we should not need separate paths
+            # for dynamic shape vs. static shape.
+            # The only reason it is done this way is to accommodate
+            # the lowering pains of backends that work better
+            # with the index_put op.
+            if self.enable_dynamic_shape:
+                start_pos = input_pos[0].item()
+                torch._check_is_size(start_pos)
+                if self.is_transposed:
+                    dim_to_slice = 2
+                else:
+                    dim_to_slice = 1
+                torch._check(start_pos < self.k_cache.size(dim_to_slice))
+                seq_length = k_val.size(dim_to_slice)
+                narrowed_k = self.k_cache.narrow(dim_to_slice, start_pos, seq_length)
+                narrowed_k_scales = self.k_cache_scales.narrow(
+                    dim_to_slice, start_pos, seq_length
+                )
+                narrowed_k_zp = self.k_cache_zero_points.narrow(
+                    dim_to_slice, start_pos, seq_length
+                )
+                narrowed_k.copy_(quantized_k_val)
+                narrowed_k_scales.copy_(k_scales)
+                narrowed_k_zp.copy_(k_zero_points)
+                # pyre-ignore: Incompatible parameter type [6]
+                narrowed_v = self.v_cache.narrow(dim_to_slice, start_pos, seq_length)
+                narrowed_v_scales = self.v_cache_scales.narrow(
+                    dim_to_slice, start_pos, seq_length
+                )
+                narrowed_v_zp = self.v_cache_zero_points.narrow(
+                    dim_to_slice, start_pos, seq_length
+                )
+                narrowed_v.copy_(quantized_v_val)
+                narrowed_v_scales.copy_(v_scales)
+                narrowed_v_zp.copy_(v_zero_points)
             else:
-                dim_to_slice = 1
-            torch._check(start_pos < self.k_cache.size(dim_to_slice))
-            seq_length = k_val.size(dim_to_slice)
-            narrowed_k = self.k_cache.narrow(dim_to_slice, start_pos, seq_length)
-            narrowed_k_scales = self.k_cache_scales.narrow(
-                dim_to_slice, start_pos, seq_length
-            )
-            narrowed_k_zp = self.k_cache_zero_points.narrow(
-                dim_to_slice, start_pos, seq_length
-            )
-            narrowed_k.copy_(quantized_k_val)
-            narrowed_k_scales.copy_(k_scales)
-            narrowed_k_zp.copy_(k_zero_points)
-            # pyre-ignore: Incompatible parameter type [6]
-            narrowed_v = self.v_cache.narrow(dim_to_slice, start_pos, seq_length)
-            narrowed_v_scales = self.v_cache_scales.narrow(
-                dim_to_slice, start_pos, seq_length
-            )
-            narrowed_v_zp = self.v_cache_zero_points.narrow(
-                dim_to_slice, start_pos, seq_length
-            )
-            narrowed_v.copy_(quantized_v_val)
-            narrowed_v_scales.copy_(v_scales)
-            narrowed_v_zp.copy_(v_zero_points)
-        else:
-            if self.is_transposed:
                 self.k_cache[:, :, input_pos] = quantized_k_val
                 self.k_cache_scales[:, :, input_pos] = k_scales
                 self.k_cache_zero_points[:, :, input_pos] = k_zero_points
                 self.v_cache[:, :, input_pos] = quantized_v_val
                 self.v_cache_scales[:, :, input_pos] = v_scales
                 self.v_cache_zero_points[:, :, input_pos] = v_zero_points
-            else:
-                self.k_cache[:, input_pos] = quantized_k_val
-                self.k_cache_scales[:, input_pos] = k_scales
-                self.k_cache_zero_points[:, input_pos] = k_zero_points
-                self.v_cache[:, input_pos] = quantized_v_val
-                self.v_cache_scales[:, input_pos] = v_scales
-                self.v_cache_zero_points[:, input_pos] = v_zero_points
+        else:
+            # Right now using custom ops on this path.
+            # In the future we can update the custom op to handle the
+            # transposed cache as well.
+            # Note that we may have to revert this change if other ET
+            # backends such as QNN want to use the quantized cache, with dynamic shape,
+            # instead of quantizing on their own.
+            # But until then, opting for code simplicity.
+            start_pos = input_pos[0].item()
+            _ = torch.ops.llama.update_quantized_cache(
+                quantized_k_val, self.k_cache, start_pos
+            )
+            _ = torch.ops.llama.update_quantized_cache(
+                k_scales, self.k_cache_scales, start_pos
+            )
+            _ = torch.ops.llama.update_quantized_cache(
+                k_zero_points, self.k_cache_zero_points, start_pos
+            )
+            _ = torch.ops.llama.update_quantized_cache(
+                quantized_v_val, self.v_cache, start_pos
+            )
+            _ = torch.ops.llama.update_quantized_cache(
+                v_scales, self.v_cache_scales, start_pos
+            )
+            _ = torch.ops.llama.update_quantized_cache(
+                v_zero_points, self.v_cache_zero_points, start_pos
+            )
 
         k_out = torch.ops.quantized_decomposed.dequantize_per_token(
             self.k_cache,
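For readers skimming the transposed-cache branch above: it relies on the `narrow` + `copy_` idiom to write new tokens into a preallocated cache buffer in place. Below is a minimal standalone sketch of that idiom, not code from this PR; the `[batch, n_heads, max_seq_len, head_dim]` layout, tensor shapes, and variable names are assumptions for illustration only.

```python
import torch

# Assumed transposed layout: [batch, n_heads, max_seq_len, head_dim] -> dim_to_slice = 2.
k_cache = torch.zeros(1, 8, 128, 64, dtype=torch.int8)
quantized_k_val = torch.randint(-128, 128, (1, 8, 4, 64), dtype=torch.int8)

start_pos = 16                        # decode position, analogous to input_pos[0].item()
seq_length = quantized_k_val.size(2)  # number of new tokens being written

# narrow() returns a view covering [start_pos, start_pos + seq_length) along dim 2,
# so copy_() writes the new tokens into the cache buffer in place.
k_cache.narrow(2, start_pos, seq_length).copy_(quantized_k_val)
```

The diff repeats this pair for the scale and zero-point tensors too, which is why `k_cache_scales`, `k_cache_zero_points`, and the value-side tensors each get their own narrow/copy_ block; per the comment in the diff, writing the update this way keeps the path friendly to backends that lower index_put-style ops more readily than the custom op.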
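The tensors being written are already per-token quantized int8 values with matching scales and zero points, and the `dequantize_per_token` call that follows the hunk reverses that before attention runs. As a rough sketch only of what per-token asymmetric int8 quantization computes, written with plain torch ops rather than the `quantized_decomposed` operators this file actually calls (the function names and shapes here are hypothetical):

```python
import torch

def quantize_per_token_asym_sketch(x: torch.Tensor, qmin: int = -128, qmax: int = 127):
    # One (scale, zero_point) pair per token: reduce over the last (head) dim.
    min_val = x.amin(dim=-1, keepdim=True)
    max_val = x.amax(dim=-1, keepdim=True)
    scales = (max_val - min_val).clamp(min=1e-9) / float(qmax - qmin)
    zero_points = qmin - torch.round(min_val / scales)
    q = torch.clamp(torch.round(x / scales) + zero_points, qmin, qmax).to(torch.int8)
    return q, scales, zero_points

def dequantize_per_token_sketch(q, scales, zero_points):
    # Reconstruct float values as (q - zero_point) * scale.
    return (q.to(torch.float32) - zero_points) * scales

k_val = torch.randn(1, 4, 8, 64)  # hypothetical [batch, seq_len, n_heads, head_dim] keys
q, scales, zps = quantize_per_token_asym_sketch(k_val)
k_approx = dequantize_per_token_sketch(q, scales, zps)
print((k_val - k_approx).abs().max())  # error bounded by roughly one quantization step
```

This is also why the cache keeps three tensors per side (values, scales, zero points): the int8 payload is meaningless without its per-token quantization parameters, so every cache update has to write all three in lockstep.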