@@ -48,6 +48,9 @@ def __init__(
                 f"Only affine symmetric and asymmetric cache types are supported: got {cache_type}"
             )
 
+        # This reduces dequantize overhead by dequantizing only up to the
+        # current input position.
+        self.return_sliced_cache = True
         # For now supporting int8 only
         self.use_custom_update_cache_op = use_custom_update_cache_op
         self.quantized_cache_dtype = torch.int8
@@ -125,19 +128,34 @@ def update(self, input_pos, k_val, v_val):
             self.v_cache_scales[:, input_pos] = v_scales
             self.v_cache_zero_points[:, input_pos] = v_zero_points
 
+        if self.return_sliced_cache:
+            k_cache_out = self.k_cache[:, : input_pos + 1]
+            k_cache_scales = self.k_cache_scales[:, : input_pos + 1]
+            k_cache_zero_points = self.k_cache_zero_points[:, : input_pos + 1]
+            v_cache_out = self.v_cache[:, : input_pos + 1]
+            v_cache_scales = self.v_cache_scales[:, : input_pos + 1]
+            v_cache_zero_points = self.v_cache_zero_points[:, : input_pos + 1]
+        else:
+            k_cache_out = self.k_cache
+            k_cache_scales = self.k_cache_scales
+            k_cache_zero_points = self.k_cache_zero_points
+            v_cache_out = self.v_cache
+            v_cache_scales = self.v_cache_scales
+            v_cache_zero_points = self.v_cache_zero_points
+
         k_out = torch.ops.quantized_decomposed.dequantize_per_token(
-            self.k_cache,
-            self.k_cache_scales,
-            self.k_cache_zero_points,
+            k_cache_out,
+            k_cache_scales,
+            k_cache_zero_points,
             torch.iinfo(self.quantized_cache_dtype).min,
             torch.iinfo(self.quantized_cache_dtype).max,
             self.quantized_cache_dtype,
             self.cache_fp_type,
         )
         v_out = torch.ops.quantized_decomposed.dequantize_per_token(
-            self.v_cache,
-            self.v_cache_scales,
-            self.v_cache_zero_points,
+            v_cache_out,
+            v_cache_scales,
+            v_cache_zero_points,
             torch.iinfo(self.quantized_cache_dtype).min,
             torch.iinfo(self.quantized_cache_dtype).max,
             self.quantized_cache_dtype,
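For context on the `return_sliced_cache` flag added above, here is a minimal sketch of why slicing the int8 cache before per-token dequantization cuts the work per decode step. Plain tensor math stands in for `torch.ops.quantized_decomposed.dequantize_per_token`; the shapes, the `dequant_per_token` helper, and the example sizes are illustrative assumptions, not part of this change.

```python
import torch


def dequant_per_token(q, scales, zero_points):
    # Per-token affine dequantization (stand-in for the decomposed op):
    # one (scale, zero_point) pair per token row, broadcast over the row.
    return (q.to(torch.float32) - zero_points) * scales


# Illustrative shapes: [batch, max_seq_len, n_heads, head_dim].
batch, max_seq_len, n_heads, head_dim = 1, 2048, 8, 64
k_cache = torch.zeros(batch, max_seq_len, n_heads, head_dim, dtype=torch.int8)
k_scales = torch.ones(batch, max_seq_len, 1, 1)
k_zero_points = torch.zeros(batch, max_seq_len, 1, 1, dtype=torch.int64)

input_pos = 16  # current decode position

# return_sliced_cache == False: every step dequantizes the full max_seq_len buffer.
full = dequant_per_token(k_cache, k_scales, k_zero_points)

# return_sliced_cache == True: only the filled prefix [0, input_pos] is dequantized.
sliced = dequant_per_token(
    k_cache[:, : input_pos + 1],
    k_scales[:, : input_pos + 1],
    k_zero_points[:, : input_pos + 1],
)

print(full.numel(), sliced.numel())  # 1048576 vs 8704 elements dequantized
```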