From 07e9d95bf2008dc89ac4635c9ed19d094c7365b2 Mon Sep 17 00:00:00 2001
From: Kimish Patel
Date: Mon, 17 Mar 2025 21:38:05 -0700
Subject: [PATCH] [Executorch][kv cache] Make quantized cache return only the
 updated cache portion

Differential Revision: [D69856554](https://our.internmc.facebook.com/intern/diff/D69856554/)

[ghstack-poisoned]
---
 .../quantized_kv_cache.py                     | 30 +++++++++++++++----
 1 file changed, 24 insertions(+), 6 deletions(-)

diff --git a/examples/models/llama/source_transformation/quantized_kv_cache.py b/examples/models/llama/source_transformation/quantized_kv_cache.py
index e7138622ed9..91b20610e66 100644
--- a/examples/models/llama/source_transformation/quantized_kv_cache.py
+++ b/examples/models/llama/source_transformation/quantized_kv_cache.py
@@ -48,6 +48,9 @@ def __init__(
                 f"Only affine symmetric and asymmetric cache types are supported: got {cache_type}"
             )
 
+        # This reduces dequantize overhead by dequantizing only upto the
+        # current input position.
+        self.return_sliced_cache = True
         # For now supporting int8 only
         self.use_custom_update_cache_op = use_custom_update_cache_op
         self.quantized_cache_dtype = torch.int8
@@ -125,19 +128,34 @@ def update(self, input_pos, k_val, v_val):
             self.v_cache_scales[:, input_pos] = v_scales
             self.v_cache_zero_points[:, input_pos] = v_zero_points
 
+        if self.return_sliced_cache:
+            k_cache_out = self.k_cache[:, :input_pos+1]
+            k_cache_scales = self.k_cache_scales[:, :input_pos+1]
+            k_cache_zero_points = self.k_cache_zero_points[:, :input_pos+1]
+            v_cache_out = self.v_cache[:, :input_pos+1]
+            v_cache_scales = self.v_cache_scales[:, :input_pos+1]
+            v_cache_zero_points = self.v_cache_zero_points[:, :input_pos+1]
+        else:
+            k_cache_out = self.k_cache
+            k_cache_scales = self.k_cache_scales
+            k_cache_zero_points = self.k_cache_zero_points
+            v_cache_out = self.v_cache
+            v_cache_scales = self.v_cache_scales
+            v_cache_zero_points = self.v_cache_zero_points
+
         k_out = torch.ops.quantized_decomposed.dequantize_per_token(
-            self.k_cache,
-            self.k_cache_scales,
-            self.k_cache_zero_points,
+            k_cache_out,
+            k_cache_scales,
+            k_cache_zero_points,
             torch.iinfo(self.quantized_cache_dtype).min,
             torch.iinfo(self.quantized_cache_dtype).max,
             self.quantized_cache_dtype,
             self.cache_fp_type,
         )
         v_out = torch.ops.quantized_decomposed.dequantize_per_token(
-            self.v_cache,
-            self.v_cache_scales,
-            self.v_cache_zero_points,
+            v_cache_out,
+            v_cache_scales,
+            v_cache_zero_points,
             torch.iinfo(self.quantized_cache_dtype).min,
             torch.iinfo(self.quantized_cache_dtype).max,
             self.quantized_cache_dtype,