From 2baebac4167809f95ecf6a84cc022fad79cb8307 Mon Sep 17 00:00:00 2001
From: Kimish Patel
Date: Fri, 20 Sep 2024 15:50:46 -0700
Subject: [PATCH] [Executorch][llama] Update SDPA op to use quantized kv cache

With a quantized kv cache, we cannot rely on sdpa to update the original
cache, so we insert a cache update op. (Eager-mode sketches of the new
update path follow the patch below.)

Differential Revision: [D62301841](https://our.internmc.facebook.com/intern/diff/D62301841/)

[ghstack-poisoned]
---
 examples/models/llama2/export_llama_lib.py    |   4 +-
 .../llama2/source_transformation/TARGETS      |   3 +
 .../quantized_kv_cache.py                     | 104 +++++++++++-------
 .../llama2/source_transformation/sdpa.py      |  21 +++-
 extension/llm/custom_ops/targets.bzl          |   1 +
 5 files changed, 90 insertions(+), 43 deletions(-)

diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py
index 5a1d6035728..e3cce4631c3 100644
--- a/examples/models/llama2/export_llama_lib.py
+++ b/examples/models/llama2/export_llama_lib.py
@@ -474,9 +474,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
         transforms.append(replace_sdpa_with_custom_op)
 
     if args.quantize_kv_cache:
-        assert (args.use_kv_cache is True) and (
-            args.use_sdpa_with_kv_cache is False
-        ), "quantize_kv_cache requires use_kv_cache=True and use_sdpa_with_kv_cache=False"
+        assert args.use_kv_cache is True, "quantize_kv_cache requires use_kv_cache=True"
         transforms.append(replace_kv_cache_with_quantized_kv_cache)
 
     if args.use_kv_cache:
diff --git a/examples/models/llama2/source_transformation/TARGETS b/examples/models/llama2/source_transformation/TARGETS
index a6dd37f5d4e..da28e735c8f 100644
--- a/examples/models/llama2/source_transformation/TARGETS
+++ b/examples/models/llama2/source_transformation/TARGETS
@@ -18,6 +18,9 @@ runtime.python_test(
     srcs = [
         "test_quantized_kv_cache.py",
     ],
+    preload_deps = [
+        "//executorch/extension/llm/custom_ops:update_quantized_cache_aot_lib",
+    ],
     deps = [
         ":quantized_kv_cache",
         "//caffe2:torch",
diff --git a/examples/models/llama2/source_transformation/quantized_kv_cache.py b/examples/models/llama2/source_transformation/quantized_kv_cache.py
index 43d6c4e251b..80d9574b22a 100644
--- a/examples/models/llama2/source_transformation/quantized_kv_cache.py
+++ b/examples/models/llama2/source_transformation/quantized_kv_cache.py
@@ -40,6 +40,7 @@ def __init__(
             raise ValueError(
                 f"Only affine symmetric and asymmetric cache types are supported: got {cache_type}"
             )
+        # For now supporting int8 only
         self.quantized_cache_dtype = torch.int8
         self.cache_fp_type = torch.float32
 
@@ -97,51 +98,78 @@ def update(self, input_pos, k_val, v_val):
             torch.int8,
         )
 
-        if self.enable_dynamic_shape:
-            start_pos = input_pos[0].item()
-            torch._check_is_size(start_pos)
-            if self.is_transposed:
-                dim_to_slice = 2
+        if self.is_transposed:
+            # We cannot use the update_quantized_cache op at the moment
+            # if the cache is transposed.
+            # Also note that we should not need separate paths
+            # for dynamic shape vs. not.
+            # The only reason it is done this way is to accommodate
+            # the lowering pains of backends that work better
+            # with the index_put op.
+            if self.enable_dynamic_shape:
+                start_pos = input_pos[0].item()
+                torch._check_is_size(start_pos)
+                if self.is_transposed:
+                    dim_to_slice = 2
+                else:
+                    dim_to_slice = 1
+                torch._check(start_pos < self.k_cache.size(dim_to_slice))
+                seq_length = k_val.size(dim_to_slice)
+                narrowed_k = self.k_cache.narrow(dim_to_slice, start_pos, seq_length)
+                narrowed_k_scales = self.k_cache_scales.narrow(
+                    dim_to_slice, start_pos, seq_length
+                )
+                narrowed_k_zp = self.k_cache_zero_points.narrow(
+                    dim_to_slice, start_pos, seq_length
+                )
+                narrowed_k.copy_(quantized_k_val)
+                narrowed_k_scales.copy_(k_scales)
+                narrowed_k_zp.copy_(k_zero_points)
+                # pyre-ignore: Incompatible parameter type [6]
+                narrowed_v = self.v_cache.narrow(dim_to_slice, start_pos, seq_length)
+                narrowed_v_scales = self.v_cache_scales.narrow(
+                    dim_to_slice, start_pos, seq_length
+                )
+                narrowed_v_zp = self.v_cache_zero_points.narrow(
+                    dim_to_slice, start_pos, seq_length
+                )
+                narrowed_v.copy_(quantized_v_val)
+                narrowed_v_scales.copy_(v_scales)
+                narrowed_v_zp.copy_(v_zero_points)
             else:
-                dim_to_slice = 1
-            torch._check(start_pos < self.k_cache.size(dim_to_slice))
-            seq_length = k_val.size(dim_to_slice)
-            narrowed_k = self.k_cache.narrow(dim_to_slice, start_pos, seq_length)
-            narrowed_k_scales = self.k_cache_scales.narrow(
-                dim_to_slice, start_pos, seq_length
-            )
-            narrowed_k_zp = self.k_cache_zero_points.narrow(
-                dim_to_slice, start_pos, seq_length
-            )
-            narrowed_k.copy_(quantized_k_val)
-            narrowed_k_scales.copy_(k_scales)
-            narrowed_k_zp.copy_(k_zero_points)
-            # pyre-ignore: Incompatible parameter type [6]
-            narrowed_v = self.v_cache.narrow(dim_to_slice, start_pos, seq_length)
-            narrowed_v_scales = self.v_cache_scales.narrow(
-                dim_to_slice, start_pos, seq_length
-            )
-            narrowed_v_zp = self.v_cache_zero_points.narrow(
-                dim_to_slice, start_pos, seq_length
-            )
-            narrowed_v.copy_(quantized_v_val)
-            narrowed_v_scales.copy_(v_scales)
-            narrowed_v_zp.copy_(v_zero_points)
-        else:
-            if self.is_transposed:
                 self.k_cache[:, :, input_pos] = quantized_k_val
                 self.k_cache_scales[:, :, input_pos] = k_scales
                 self.k_cache_zero_points[:, :, input_pos] = k_zero_points
                 self.v_cache[:, :, input_pos] = quantized_v_val
                 self.v_cache_scales[:, :, input_pos] = v_scales
                 self.v_cache_zero_points[:, :, input_pos] = v_zero_points
-            else:
-                self.k_cache[:, input_pos] = quantized_k_val
-                self.k_cache_scales[:, input_pos] = k_scales
-                self.k_cache_zero_points[:, input_pos] = k_zero_points
-                self.v_cache[:, input_pos] = quantized_v_val
-                self.v_cache_scales[:, input_pos] = v_scales
-                self.v_cache_zero_points[:, input_pos] = v_zero_points
+        else:
+            # Right now we are using custom ops on this path.
+            # In the future we can update the custom op to handle
+            # the transposed cache as well.
+            # Note that we may have to revert this change if other ET
+            # backends such as QNN want to use the quantized cache with
+            # dynamic shapes instead of quantizing on their own.
+            # But until then, opting for code simplicity.
+            start_pos = input_pos[0].item()
+            _ = torch.ops.llama.update_quantized_cache(
+                quantized_k_val, self.k_cache, start_pos
+            )
+            _ = torch.ops.llama.update_quantized_cache(
+                k_scales, self.k_cache_scales, start_pos
+            )
+            _ = torch.ops.llama.update_quantized_cache(
+                k_zero_points, self.k_cache_zero_points, start_pos
+            )
+            _ = torch.ops.llama.update_quantized_cache(
+                quantized_v_val, self.v_cache, start_pos
+            )
+            _ = torch.ops.llama.update_quantized_cache(
+                v_scales, self.v_cache_scales, start_pos
+            )
+            _ = torch.ops.llama.update_quantized_cache(
+                v_zero_points, self.v_cache_zero_points, start_pos
+            )
 
         k_out = torch.ops.quantized_decomposed.dequantize_per_token(
             self.k_cache,
diff --git a/examples/models/llama2/source_transformation/sdpa.py b/examples/models/llama2/source_transformation/sdpa.py
index 0d2e4852e94..263a98a66b3 100644
--- a/examples/models/llama2/source_transformation/sdpa.py
+++ b/examples/models/llama2/source_transformation/sdpa.py
@@ -14,6 +14,9 @@
 import torch
 
 from executorch.examples.models.llama2.llama_transformer import KVCache, SDPA
+from executorch.examples.models.llama2.source_transformation.quantized_kv_cache import (
+    QuantizedKVCache,
+)
 
 
 class SDPACustom(torch.nn.Module):
@@ -36,12 +39,26 @@ def forward(
         seqlen,
         mask,
     ):
+        k_cache = self.kv_cache.k_cache
+        v_cache = self.kv_cache.v_cache
+        if isinstance(self.kv_cache, QuantizedKVCache):
+            # Updates the quantized cache, scales, and zero points,
+            # and returns the dequantized kv cache.
+            # Not the most optimal; optimizations to follow next.
+            k_cache, v_cache = self.kv_cache.update(input_pos, k, v)
+        # Note that this path will still in-place mutate k_cache and v_cache.
+        # When we are not using the quantized kv cache, this will just mutate
+        # the original kv cache.
+        # When we are using the quantized kv cache, this will mutate the
+        # k_cache and v_cache returned from the cache update operation.
+        # That operation just dequantizes the cache and returns it.
+        # Future diffs will optimize this.
         output = torch.ops.llama.sdpa_with_kv_cache(
             q,
             k,
             v,
-            self.kv_cache.k_cache,
-            self.kv_cache.v_cache,
+            k_cache,
+            v_cache,
             input_pos[-1].item(),
             seqlen,
             None,  # Attention mask
diff --git a/extension/llm/custom_ops/targets.bzl b/extension/llm/custom_ops/targets.bzl
index 503e4a0c7bd..a5bf280d76f 100644
--- a/extension/llm/custom_ops/targets.bzl
+++ b/extension/llm/custom_ops/targets.bzl
@@ -20,6 +20,7 @@ def define_common_targets():
             "op_sdpa.h",
         ],
         exported_deps = [
+            ":update_quantized_cache",
             "//executorch/runtime/kernel:kernel_includes",
             "//executorch/kernels/portable/cpu:scalar_utils",
             "//executorch/kernels/optimized:libblas{}".format(mkl_dep),
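
For eager-mode reference: on the non-transposed path above, torch.ops.llama.update_quantized_cache replaces the narrow/copy_ sequence that the dynamic-shape branch still spells out. A minimal sketch of the assumed semantics, assuming a non-transposed [batch, seq_len, n_heads, head_dim] cache layout (so the sequence dimension is dim 1); update_quantized_cache_reference is a hypothetical helper, not part of this patch:

import torch


def update_quantized_cache_reference(
    value: torch.Tensor, cache: torch.Tensor, start_pos: int
) -> None:
    # Write `value` into cache[:, start_pos : start_pos + seq_length] in place,
    # mirroring the narrow/copy_ path in the dynamic-shape branch above.
    seq_length = value.size(1)
    cache.narrow(1, start_pos, seq_length).copy_(value)


# Example: write a one-token update at position 7 of an int8 cache.
cache = torch.zeros(1, 128, 8, 64, dtype=torch.int8)
new_k = torch.randint(-128, 128, (1, 1, 8, 64), dtype=torch.int8)
update_quantized_cache_reference(new_k, cache, start_pos=7)
assert torch.equal(cache[:, 7:8], new_k)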
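
The cache itself stores int8 values plus per-token scales and zero points, and dequantize_per_token (visible in the hunk above) reconstructs float32 values on read. A round-trip sketch using the decomposed quantization ops; the choose_qparams_per_token_asymmetric and quantize_per_token signatures here are my assumption of that API's interface, not something this patch pins down:

import torch

# Assumption: importing this module registers the quantized_decomposed ops.
import torch.ao.quantization.fx._decomposed  # noqa: F401

k_val = torch.randn(1, 1, 8, 64, dtype=torch.float32)

# Per-token asymmetric qparams for int8, as on the cache write path.
scales, zero_points = (
    torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric(
        k_val, torch.int8
    )
)
quantized_k = torch.ops.quantized_decomposed.quantize_per_token(
    k_val, scales, zero_points, -128, 127, torch.int8
)
# On read, dequantize back to float32 (the cache_fp_type above).
k_out = torch.ops.quantized_decomposed.dequantize_per_token(
    quantized_k, scales, zero_points, -128, 127, torch.int8, torch.float32
)
# The round trip should agree to within one quantization step per token.
assert torch.allclose(k_val, k_out, atol=scales.max().item())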
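
Relatedly, the new preload_deps entry in TARGETS is what makes torch.ops.llama.update_quantized_cache resolvable from the Python test under Buck. Outside of Buck, the op's shared library has to be loaded explicitly before the op exists; a sketch, where the .so path is hypothetical and depends on your build output:

import torch

# Hypothetical path to the library built from
# //executorch/extension/llm/custom_ops:update_quantized_cache_aot_lib;
# substitute your actual build-output location.
torch.ops.load_library("path/to/libupdate_quantized_cache_aot_lib.so")

cache = torch.zeros(1, 128, 8, 64, dtype=torch.int8)
update = torch.randint(-128, 128, (1, 1, 8, 64), dtype=torch.int8)
torch.ops.llama.update_quantized_cache(update, cache, 7)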