@@ -40,6 +40,7 @@ def __init__(
             raise ValueError(
                 f"Only affine symmetric and asymmetric cache types are supported: got {cache_type}"
             )
+
         # For now supporting int8 only
         self.quantized_cache_dtype = torch.int8
         self.cache_fp_type = torch.float32
@@ -97,51 +98,78 @@ def update(self, input_pos, k_val, v_val):
             torch.int8,
         )

-        if self.enable_dynamic_shape:
-            start_pos = input_pos[0].item()
-            torch._check_is_size(start_pos)
-            if self.is_transposed:
-                dim_to_slice = 2
+        if self.is_transposed:
+            # We cannot use update_cache op at the moment
+            # if the cache is transposed.
+            # Also note that we should not need separate paths
+            # for dynamic shape vs. not.
+            # The only reason it is done this way is to accommodate
+            # the lowering pains of backends that work better
+            # with the index_put op.
+            if self.enable_dynamic_shape:
+                start_pos = input_pos[0].item()
+                torch._check_is_size(start_pos)
+                if self.is_transposed:
+                    dim_to_slice = 2
+                else:
+                    dim_to_slice = 1
+                torch._check(start_pos < self.k_cache.size(dim_to_slice))
+                seq_length = k_val.size(dim_to_slice)
+                narrowed_k = self.k_cache.narrow(dim_to_slice, start_pos, seq_length)
+                narrowed_k_scales = self.k_cache_scales.narrow(
+                    dim_to_slice, start_pos, seq_length
+                )
+                narrowed_k_zp = self.k_cache_zero_points.narrow(
+                    dim_to_slice, start_pos, seq_length
+                )
+                narrowed_k.copy_(quantized_k_val)
+                narrowed_k_scales.copy_(k_scales)
+                narrowed_k_zp.copy_(k_zero_points)
+                # pyre-ignore: Incompatible parameter type [6]
+                narrowed_v = self.v_cache.narrow(dim_to_slice, start_pos, seq_length)
+                narrowed_v_scales = self.v_cache_scales.narrow(
+                    dim_to_slice, start_pos, seq_length
+                )
+                narrowed_v_zp = self.v_cache_zero_points.narrow(
+                    dim_to_slice, start_pos, seq_length
+                )
+                narrowed_v.copy_(quantized_v_val)
+                narrowed_v_scales.copy_(v_scales)
+                narrowed_v_zp.copy_(v_zero_points)
             else:
-                dim_to_slice = 1
-            torch._check(start_pos < self.k_cache.size(dim_to_slice))
-            seq_length = k_val.size(dim_to_slice)
-            narrowed_k = self.k_cache.narrow(dim_to_slice, start_pos, seq_length)
-            narrowed_k_scales = self.k_cache_scales.narrow(
-                dim_to_slice, start_pos, seq_length
-            )
-            narrowed_k_zp = self.k_cache_zero_points.narrow(
-                dim_to_slice, start_pos, seq_length
-            )
-            narrowed_k.copy_(quantized_k_val)
-            narrowed_k_scales.copy_(k_scales)
-            narrowed_k_zp.copy_(k_zero_points)
-            # pyre-ignore: Incompatible parameter type [6]
-            narrowed_v = self.v_cache.narrow(dim_to_slice, start_pos, seq_length)
-            narrowed_v_scales = self.v_cache_scales.narrow(
-                dim_to_slice, start_pos, seq_length
-            )
-            narrowed_v_zp = self.v_cache_zero_points.narrow(
-                dim_to_slice, start_pos, seq_length
-            )
-            narrowed_v.copy_(quantized_v_val)
-            narrowed_v_scales.copy_(v_scales)
-            narrowed_v_zp.copy_(v_zero_points)
-        else:
-            if self.is_transposed:
                 self.k_cache[:, :, input_pos] = quantized_k_val
                 self.k_cache_scales[:, :, input_pos] = k_scales
                 self.k_cache_zero_points[:, :, input_pos] = k_zero_points
                 self.v_cache[:, :, input_pos] = quantized_v_val
                 self.v_cache_scales[:, :, input_pos] = v_scales
                 self.v_cache_zero_points[:, :, input_pos] = v_zero_points
-            else:
-                self.k_cache[:, input_pos] = quantized_k_val
-                self.k_cache_scales[:, input_pos] = k_scales
-                self.k_cache_zero_points[:, input_pos] = k_zero_points
-                self.v_cache[:, input_pos] = quantized_v_val
-                self.v_cache_scales[:, input_pos] = v_scales
-                self.v_cache_zero_points[:, input_pos] = v_zero_points
+        else:
+            # Right now using custom ops on this path.
+            # In future we can update the custom op to handle the transposed cache
+            # as well.
+            # Note that we may have to revert this change if other ET
+            # backends such as QNN want to use the quantized cache with dynamic shape,
+            # instead of quantizing on their own.
+            # But until then, opting for code simplicity.
+            start_pos = input_pos[0].item()
+            _ = torch.ops.llama.update_quantized_cache(
+                quantized_k_val, self.k_cache, start_pos
+            )
+            _ = torch.ops.llama.update_quantized_cache(
+                k_scales, self.k_cache_scales, start_pos
+            )
+            _ = torch.ops.llama.update_quantized_cache(
+                k_zero_points, self.k_cache_zero_points, start_pos
+            )
+            _ = torch.ops.llama.update_quantized_cache(
+                quantized_v_val, self.v_cache, start_pos
+            )
+            _ = torch.ops.llama.update_quantized_cache(
+                v_scales, self.v_cache_scales, start_pos
+            )
+            _ = torch.ops.llama.update_quantized_cache(
+                v_zero_points, self.v_cache_zero_points, start_pos
+            )

         k_out = torch.ops.quantized_decomposed.dequantize_per_token(
             self.k_cache,
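
Both branches of update() write the same data into the caches; they differ only in which op performs the write. Assuming the custom op simply overwrites positions from start_pos onward along the sequence dimension (dim 1 in the non-transposed layout, mirroring the narrow/copy_ path above), a rough reference sketch of torch.ops.llama.update_quantized_cache could look like the following; update_cache_reference is an illustrative name and is not part of the diff or the op's actual kernel:

```python
import torch


def update_cache_reference(value: torch.Tensor, cache: torch.Tensor, start_pos: int) -> None:
    # Hypothetical, simplified behavior for the non-transposed layout:
    # overwrite value.size(1) positions of the cache in place, starting at
    # start_pos along the sequence dimension (dim 1).
    seq_len = value.size(1)
    cache.narrow(1, start_pos, seq_len).copy_(value)
```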
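For reference, the quantized cache stores int8 values together with per-token scales and zero points, and the dequantize_per_token call at the end of the hunk reconstructs floats as (q - zero_point) * scale. Below is a minimal, self-contained sketch of that per-token asymmetric int8 round trip in plain PyTorch; it is illustrative only and does not call the quantized_decomposed or llama custom ops used in the diff.

```python
import torch


def quantize_per_token(x: torch.Tensor, qmin: int = -128, qmax: int = 127):
    # One (scale, zero_point) pair per token, i.e. per slice over the last dim.
    x_min = x.amin(dim=-1, keepdim=True)
    x_max = x.amax(dim=-1, keepdim=True)
    scale = (x_max - x_min).clamp(min=1e-9) / (qmax - qmin)
    zero_point = torch.clamp(qmin - torch.round(x_min / scale), qmin, qmax)
    q = torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax).to(torch.int8)
    return q, scale, zero_point


def dequantize_per_token(q: torch.Tensor, scale: torch.Tensor, zero_point: torch.Tensor) -> torch.Tensor:
    # Reconstruct float values: (q - zero_point) * scale.
    return (q.to(torch.float32) - zero_point) * scale


k_val = torch.randn(1, 8, 4, 64)  # illustrative shape only
q_k, k_scales, k_zp = quantize_per_token(k_val)
k_restored = dequantize_per_token(q_k, k_scales, k_zp)
print((k_val - k_restored).abs().max())  # small quantization error
```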