From 7b9ab9292bc9a73e8aafa5f459eb31f6d8b978f9 Mon Sep 17 00:00:00 2001
From: Kimish Patel
Date: Tue, 1 Jul 2025 06:48:46 -0700
Subject: [PATCH 1/3] [Executorch][llm] Make runner return error if execution was not successful

At the moment we continue execution and the stack fails later on, as I found
when running with quantized kv cache + ring attention.

Differential Revision: [D77516822](https://our.internmc.facebook.com/intern/diff/D77516822/)

ghstack-source-id: 293635304
Pull Request resolved: https://github.com/pytorch/executorch/pull/12129
---
 examples/models/llama/main.cpp | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/examples/models/llama/main.cpp b/examples/models/llama/main.cpp
index 5d34bf932e7..25b840f260b 100644
--- a/examples/models/llama/main.cpp
+++ b/examples/models/llama/main.cpp
@@ -100,12 +100,20 @@ int32_t main(int32_t argc, char** argv) {
   }
 
   if (warmup) {
-    runner->warmup(prompt, /*max_new_tokens=*/seq_len);
+    auto error = runner->warmup(prompt, /*max_new_tokens=*/seq_len);
+    if (error != executorch::runtime::Error::Ok) {
+      ET_LOG(Error, "Failed to warmup llama runner");
+      return 1;
+    }
   }
   // generate
   executorch::extension::llm::GenerationConfig config{
       .seq_len = seq_len, .temperature = temperature};
-  runner->generate(prompt, config);
+  auto error = runner->generate(prompt, config);
+  if (error != executorch::runtime::Error::Ok) {
+    ET_LOG(Error, "Failed to generate with llama runner");
+    return 1;
+  }
 
   return 0;
 }

From 09d57aa6418a8fec4755cc197c056a23f7afab5c Mon Sep 17 00:00:00 2001
From: Kimish Patel
Date: Tue, 1 Jul 2025 09:38:47 -0700
Subject: [PATCH 2/3] [Executorch][llm] Make mask tensor float only for sdpa

Now that we support quantized SDPA, the query tensor can be quantized while the
attention mask must be float (the only allowed type), so checking the mask
dtype against the query dtype no longer makes sense.

Differential Revision: [D77516821](https://our.internmc.facebook.com/intern/diff/D77516821/)

ghstack-source-id: 293661338
Pull Request resolved: https://github.com/pytorch/executorch/pull/12131
---
 extension/llm/custom_ops/op_sdpa.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/extension/llm/custom_ops/op_sdpa.cpp b/extension/llm/custom_ops/op_sdpa.cpp
index 91802a8445d..c98fa1729fa 100644
--- a/extension/llm/custom_ops/op_sdpa.cpp
+++ b/extension/llm/custom_ops/op_sdpa.cpp
@@ -59,8 +59,8 @@ bool validate_flash_attention_args(
 
   ET_CHECK_OR_RETURN_FALSE(
       !attn_mask.has_value() ||
-          attn_mask.value().scalar_type() == query.scalar_type(),
-      "Attention mask must be a 2D tensor");
+          attn_mask.value().scalar_type() == ScalarType::Float,
+      "Attention mask must be a Float tensor");
 
   ET_CHECK_OR_RETURN_FALSE(
       is_contiguous_dim_order(query.dim_order().data(), query.dim()),

From e7a0c585f62435179fa0964e1bd4d74ab2435466 Mon Sep 17 00:00:00 2001
From: Kimish Patel
Date: Tue, 1 Jul 2025 09:38:51 -0700
Subject: [PATCH 3/3] [Executorch][llm] Fix ring kv cache when used with quantized kv cache and sdpa

When using quantized kv cache together with SDPA, there were two bugs:

1. return_float_values was not carried over to QuantizedRingKVCache, so the
   ring cache fell back to the default and returned dequantized float values.
2. With a quantized kv cache, the SDPA module stores the kv_cache that is
   owned by the attention module. When replacing the kv cache in Attention we
   also have to update the reference held by SDPA (see the sketch below).
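To make the second bug concrete, here is a minimal, hedged sketch of the
stale-reference problem; ToyCache, ToySDPA, and ToyAttention are illustrative
stand-ins, not the real ExecuTorch modules:

```python
# The SDPA wrapper keeps its own attribute pointing at the cache object, so
# rebinding only attention.kv_cache leaves SDPA holding the old (non-ring) cache.
class ToyCache:
    pass


class ToySDPA:
    def __init__(self, kv_cache):
        self.kv_cache = kv_cache  # same object the attention module owns


class ToyAttention:
    def __init__(self):
        self.kv_cache = ToyCache()
        self.SDPA = ToySDPA(self.kv_cache)


attn = ToyAttention()
attn.kv_cache = ToyCache()  # swap in the "ring" cache
assert attn.SDPA.kv_cache is not attn.kv_cache  # SDPA still sees the old cache

attn.SDPA.kv_cache = attn.kv_cache  # what this patch does for QuantizedSDPA
assert attn.SDPA.kv_cache is attn.kv_cache  # both now share the ring cache
```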
Differential Revision: [D77516823](https://our.internmc.facebook.com/intern/diff/D77516823/)

ghstack-source-id: 293661340
Pull Request resolved: https://github.com/pytorch/executorch/pull/12132
---
 .../llama/source_transformation/custom_kv_cache.py | 14 +++++++++++++-
 1 file changed, 13 insertions(+), 1 deletion(-)

diff --git a/examples/models/llama/source_transformation/custom_kv_cache.py b/examples/models/llama/source_transformation/custom_kv_cache.py
index 25ec207d0e0..0fbdd1936ef 100644
--- a/examples/models/llama/source_transformation/custom_kv_cache.py
+++ b/examples/models/llama/source_transformation/custom_kv_cache.py
@@ -43,6 +43,7 @@ def __init__(
         head_dim,
         cache_type: QuantizedCacheType = QuantizedCacheType.AffineSymmetric,
         use_custom_update_cache_op: bool = False,
+        return_float_values: bool = True,
     ):
         super().__init__()
         if cache_type not in (
@@ -57,7 +58,7 @@
         self.use_custom_update_cache_op = use_custom_update_cache_op
         self.quantized_cache_dtype = torch.int8
         self.cache_fp_type = torch.float32
-        self.return_float_values = True
+        self.return_float_values = return_float_values
         self.max_context_length = max_context_length
         cache_shape = (max_batch_size, max_context_length, n_heads, head_dim)
         scale_shape = (max_batch_size, max_context_length, n_heads, 1)
@@ -400,6 +401,7 @@ def __init__(
         head_dim,
         cache_type: QuantizedCacheType = QuantizedCacheType.AffineSymmetric,
         use_custom_update_cache_op: bool = False,
+        return_float_values: bool = True,
     ):
         # Look at attention.py for explanation on why max_context_length * 2
         super().__init__(
@@ -409,6 +411,7 @@
             head_dim,
             cache_type,
             use_custom_update_cache_op,
+            return_float_values,
         )
         self.cache_positions_manager = CachePositionsManager(self.max_context_length)
         self.is_ring_buffer = True
@@ -459,6 +462,7 @@ def from_quantized_kv_cache(
             head_dim,
             kv_cache.cache_type,
             kv_cache.use_custom_update_cache_op,
+            kv_cache.return_float_values,
         )
 
 
@@ -583,4 +587,12 @@ def replace_kv_cache_with_ring_kv_cache(module, layer_sizes):
         # it is not doing causal attention
         if "SDPACustom" in attention.SDPA.__class__.__name__:
             attention.SDPA.use_attention_mask = True
+        # QuantizedSDPA has to store the kv_cache in order to obtain
+        # the scales and zero points for the k and v caches.
+        # So if we replaced the attention module's quantized kv cache with
+        # QuantizedRingKVCache, then we also have to replace the attention's
+        # SDPA module's kv_cache so that it refers to the same kv cache.
+        if "QuantizedSDPA" in attention.SDPA.__class__.__name__:
+            attention.SDPA.use_attention_mask = True
+            attention.SDPA.kv_cache = attention.kv_cache
     return module
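For the first bug, a hedged sketch of why the flag has to be forwarded; the
Toy* classes below are illustrative only and do not mirror the real
QuantizedKVCache / QuantizedRingKVCache API:

```python
import torch


# Toy stand-in: the flag decides whether reads return raw int8 values
# or dequantized float32 values.
class ToyQuantizedCache:
    def __init__(self, return_float_values: bool = True):
        self.return_float_values = return_float_values
        self.k_cache = torch.zeros(1, 8, dtype=torch.int8)

    def read_k(self) -> torch.Tensor:
        return self.k_cache.float() if self.return_float_values else self.k_cache


class ToyRingCache(ToyQuantizedCache):
    @classmethod
    def from_cache(cls, cache, propagate_flag: bool):
        # Before the fix the flag was dropped and defaulted to True;
        # the fix forwards cache.return_float_values instead.
        return cls(cache.return_float_values if propagate_flag else True)


src = ToyQuantizedCache(return_float_values=False)  # QuantizedSDPA expects int8
print(ToyRingCache.from_cache(src, propagate_flag=False).read_k().dtype)  # float32 (bug)
print(ToyRingCache.from_cache(src, propagate_flag=True).read_k().dtype)   # int8 (fixed)
```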