From 7b9ab9292bc9a73e8aafa5f459eb31f6d8b978f9 Mon Sep 17 00:00:00 2001
From: Kimish Patel
Date: Tue, 1 Jul 2025 06:48:46 -0700
Subject: [PATCH 1/2] [Executorch][llm] Make runner return error if execution
 was not successful

At the moment we continue execution and the stack fails later on, as I found
when running with quantized KV cache + ring attention.

Differential Revision: [D77516822](https://our.internmc.facebook.com/intern/diff/D77516822/)

ghstack-source-id: 293635304
Pull Request resolved: https://github.com/pytorch/executorch/pull/12129
---
 examples/models/llama/main.cpp | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/examples/models/llama/main.cpp b/examples/models/llama/main.cpp
index 5d34bf932e7..25b840f260b 100644
--- a/examples/models/llama/main.cpp
+++ b/examples/models/llama/main.cpp
@@ -100,12 +100,20 @@ int32_t main(int32_t argc, char** argv) {
   }
 
   if (warmup) {
-    runner->warmup(prompt, /*max_new_tokens=*/seq_len);
+    auto error = runner->warmup(prompt, /*max_new_tokens=*/seq_len);
+    if (error != executorch::runtime::Error::Ok) {
+      ET_LOG(Error, "Failed to warmup llama runner");
+      return 1;
+    }
   }
   // generate
   executorch::extension::llm::GenerationConfig config{
       .seq_len = seq_len, .temperature = temperature};
-  runner->generate(prompt, config);
+  auto error = runner->generate(prompt, config);
+  if (error != executorch::runtime::Error::Ok) {
+    ET_LOG(Error, "Failed to generate with llama runner");
+    return 1;
+  }
 
   return 0;
 }

From 09d57aa6418a8fec4755cc197c056a23f7afab5c Mon Sep 17 00:00:00 2001
From: Kimish Patel
Date: Tue, 1 Jul 2025 09:38:47 -0700
Subject: [PATCH 2/2] [Executorch][llm] Make mask tensor float only for sdpa

Now that we support quantized SDPA, the query tensor can be quantized while
the attention mask must be float (the only type allowed), so this check no
longer makes sense.

Differential Revision: [D77516821](https://our.internmc.facebook.com/intern/diff/D77516821/)

ghstack-source-id: 293661338
Pull Request resolved: https://github.com/pytorch/executorch/pull/12131
---
 extension/llm/custom_ops/op_sdpa.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/extension/llm/custom_ops/op_sdpa.cpp b/extension/llm/custom_ops/op_sdpa.cpp
index 91802a8445d..c98fa1729fa 100644
--- a/extension/llm/custom_ops/op_sdpa.cpp
+++ b/extension/llm/custom_ops/op_sdpa.cpp
@@ -59,8 +59,8 @@ bool validate_flash_attention_args(
 
   ET_CHECK_OR_RETURN_FALSE(
       !attn_mask.has_value() ||
-          attn_mask.value().scalar_type() == query.scalar_type(),
-      "Attention mask must be a 2D tensor");
+          attn_mask.value().scalar_type() == ScalarType::Float,
+      "Attention mask must be a Float tensor");
 
   ET_CHECK_OR_RETURN_FALSE(
       is_contiguous_dim_order(query.dim_order().data(), query.dim()),
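
Note on the error handling added in PATCH 1/2: the two check-and-log blocks
repeat the same four lines, which is exactly the kind of duplication that
invites copy-paste slips in the log message. Below is a minimal sketch of a
helper that folds the pattern together; it assumes only what main.cpp already
uses (executorch::runtime::Error and ET_LOG), and the name report_failure is
hypothetical:

    // Hypothetical helper, not part of the patch: logs a failed runner call
    // and tells the caller whether to bail out with a non-zero exit code.
    // ET_LOG accepts printf-style format arguments.
    static bool report_failure(executorch::runtime::Error error, const char* op) {
      if (error != executorch::runtime::Error::Ok) {
        ET_LOG(Error, "Failed to %s with llama runner", op);
        return true;
      }
      return false;
    }

With such a helper each call site collapses to a single guarded call:

    if (report_failure(runner->generate(prompt, config), "generate")) {
      return 1;
    }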
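
Why the old check removed in PATCH 2/2 broke quantized SDPA: once the query
may be quantized, the mask dtype can no longer be required to match the query
dtype. A standalone illustration of the pre- and post-patch predicates, using
a hypothetical stand-in for executorch::aten::ScalarType:

    #include <cstdio>

    // Stand-in for executorch::aten::ScalarType (illustration only).
    enum class ScalarType { Float, Char };

    // Pre-patch rule: a present mask had to match the query dtype.
    bool mask_ok_old(ScalarType mask, ScalarType query) { return mask == query; }

    // Post-patch rule: a present mask must be Float, whatever the query dtype.
    bool mask_ok_new(ScalarType mask) { return mask == ScalarType::Float; }

    int main() {
      // Quantized SDPA: the query arrives as Char (int8), the mask stays Float.
      std::printf("old: %d\n", mask_ok_old(ScalarType::Float, ScalarType::Char)); // 0: rejected
      std::printf("new: %d\n", mask_ok_new(ScalarType::Float));                   // 1: accepted
      return 0;
    }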