From 64e93971ecd5ab4fa2399da91c581b2f29a7bd09 Mon Sep 17 00:00:00 2001 From: Kimish Patel Date: Tue, 14 Oct 2025 20:14:00 -0700 Subject: [PATCH] Fix max seq length bug Differential Revision: D84562463 Pull Request resolved: https://github.com/pytorch/executorch/pull/15084 (cherry picked from commit f1e25484fdaaf9453a1a0a61e5d2025080c9ebd8) --- extension/llm/custom_ops/custom_ops.py | 4 ++-- extension/llm/export/builder.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/extension/llm/custom_ops/custom_ops.py b/extension/llm/custom_ops/custom_ops.py index 3c3243142cf..dfa357fe356 100644 --- a/extension/llm/custom_ops/custom_ops.py +++ b/extension/llm/custom_ops/custom_ops.py @@ -207,8 +207,8 @@ def _validate_update_cache_params( 1 ), f"Start position {start_pos} must be less than sequence length {cache.size(1)}" - torch._check((start_pos + seq_len) < cache.size(1)) - assert (start_pos + seq_len) < cache.size( + torch._check((start_pos + seq_len) <= cache.size(1)) + assert (start_pos + seq_len) <= cache.size( 1 ), f"Start position + length = {start_pos + seq_len} must be less than sequence length {cache.size(1)}" diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py index 01000f3564c..f8c556f351c 100644 --- a/extension/llm/export/builder.py +++ b/extension/llm/export/builder.py @@ -142,7 +142,8 @@ def __init__( {1: torch.export.Dim("token_dim", max=self.max_seq_len - 1)}, ) else: - # Two input arguments: tokens and input_pos but input_pos is static shape + # Two input arguments: tokens and input_pos but input_pos is static shape. + self.dynamic_shapes = ( {1: torch.export.Dim("token_dim", max=self.max_seq_len)}, {"input_pos": {0: 1}},