diff --git a/examples/models/llama/source_transformation/quantized_kv_cache.py b/examples/models/llama/source_transformation/quantized_kv_cache.py
index e7138622ed9..91b20610e66 100644
--- a/examples/models/llama/source_transformation/quantized_kv_cache.py
+++ b/examples/models/llama/source_transformation/quantized_kv_cache.py
@@ -48,6 +48,9 @@ def __init__(
                 f"Only affine symmetric and asymmetric cache types are supported: got {cache_type}"
             )
 
+        # This reduces dequantization overhead by dequantizing only up to the
+        # current input position.
+        self.return_sliced_cache = True
         # For now supporting int8 only
         self.use_custom_update_cache_op = use_custom_update_cache_op
         self.quantized_cache_dtype = torch.int8
@@ -125,19 +128,34 @@ def update(self, input_pos, k_val, v_val):
             self.v_cache_scales[:, input_pos] = v_scales
             self.v_cache_zero_points[:, input_pos] = v_zero_points
 
+        if self.return_sliced_cache:
+            k_cache_out = self.k_cache[:, : input_pos + 1]
+            k_cache_scales = self.k_cache_scales[:, : input_pos + 1]
+            k_cache_zero_points = self.k_cache_zero_points[:, : input_pos + 1]
+            v_cache_out = self.v_cache[:, : input_pos + 1]
+            v_cache_scales = self.v_cache_scales[:, : input_pos + 1]
+            v_cache_zero_points = self.v_cache_zero_points[:, : input_pos + 1]
+        else:
+            k_cache_out = self.k_cache
+            k_cache_scales = self.k_cache_scales
+            k_cache_zero_points = self.k_cache_zero_points
+            v_cache_out = self.v_cache
+            v_cache_scales = self.v_cache_scales
+            v_cache_zero_points = self.v_cache_zero_points
+
         k_out = torch.ops.quantized_decomposed.dequantize_per_token(
-            self.k_cache,
-            self.k_cache_scales,
-            self.k_cache_zero_points,
+            k_cache_out,
+            k_cache_scales,
+            k_cache_zero_points,
             torch.iinfo(self.quantized_cache_dtype).min,
             torch.iinfo(self.quantized_cache_dtype).max,
             self.quantized_cache_dtype,
             self.cache_fp_type,
         )
         v_out = torch.ops.quantized_decomposed.dequantize_per_token(
-            self.v_cache,
-            self.v_cache_scales,
-            self.v_cache_zero_points,
+            v_cache_out,
+            v_cache_scales,
+            v_cache_zero_points,
             torch.iinfo(self.quantized_cache_dtype).min,
             torch.iinfo(self.quantized_cache_dtype).max,
             self.quantized_cache_dtype,
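
For context, here is a minimal standalone sketch (not the ExecuTorch implementation) of the saving this change enables during decode: with the sliced path, each step dequantizes only the input_pos + 1 populated positions instead of the full max_seq_len cache. Plain tensor math stands in for the quantized_decomposed op, and the shapes and names (max_seq_len, head_dim, input_pos) are illustrative only.

    import torch

    # Hypothetical shapes for illustration only.
    max_seq_len, head_dim = 128, 64
    quantized_cache = torch.randint(
        -128, 128, (1, max_seq_len, head_dim), dtype=torch.int8
    )
    scales = torch.rand(1, max_seq_len, 1) * 0.1  # per-token scales
    zero_points = torch.zeros(1, max_seq_len, 1, dtype=torch.int8)

    input_pos = 5  # decoding token 6; positions 6..max_seq_len-1 are still unused

    # Unsliced path: dequantize all max_seq_len positions every step.
    full = (quantized_cache.to(torch.float32) - zero_points) * scales

    # Sliced path: dequantize only the input_pos + 1 populated positions.
    sliced = (
        quantized_cache[:, : input_pos + 1].to(torch.float32)
        - zero_points[:, : input_pos + 1]
    ) * scales[:, : input_pos + 1]

    # The populated prefix is identical; the sliced path just skips the rest.
    assert torch.allclose(full[:, : input_pos + 1], sliced)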