diff --git a/extension/llm/custom_ops/custom_ops.py b/extension/llm/custom_ops/custom_ops.py
index 3c3243142cf..dfa357fe356 100644
--- a/extension/llm/custom_ops/custom_ops.py
+++ b/extension/llm/custom_ops/custom_ops.py
@@ -207,8 +207,8 @@ def _validate_update_cache_params(
         1
     ), f"Start position {start_pos} must be less than sequence length {cache.size(1)}"
 
-    torch._check((start_pos + seq_len) < cache.size(1))
-    assert (start_pos + seq_len) < cache.size(
+    torch._check((start_pos + seq_len) <= cache.size(1))
+    assert (start_pos + seq_len) <= cache.size(
         1
     ), f"Start position + length = {start_pos + seq_len} must be less than sequence length {cache.size(1)}"
 
diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py
index da5c3324662..f8c556f351c 100644
--- a/extension/llm/export/builder.py
+++ b/extension/llm/export/builder.py
@@ -144,12 +144,8 @@ def __init__(
 
             else:
                 # Two input arguments: tokens and input_pos but input_pos is static shape.
-                # A runtime assertion is added by torch.ops.llama.update_cache requires that
-                # L['tokens'].size()[1] + input_pos[0].item() < self.max_seq_len
-                # This consttaint L['tokens'].size()[1] to be elf.max_seq_len-1
-                # run with TORCH_LOGS=+dynamic for details
                 self.dynamic_shapes = (
-                    {1: torch.export.Dim("token_dim", max=self.max_seq_len - 1)},
+                    {1: torch.export.Dim("token_dim", max=self.max_seq_len)},
                     {"input_pos": {0: 1}},
                 )
 
diff --git a/extension/llm/export/test/test_builder.py b/extension/llm/export/test/test_builder.py
index 7883480c1e7..8bf591813ec 100644
--- a/extension/llm/export/test/test_builder.py
+++ b/extension/llm/export/test/test_builder.py
@@ -88,7 +88,7 @@ def test_get_dynamic_shape_with_dynamic_shape_enabled_with_kv_cache(self) -> Non
         # Check first element (tokens dimension)
         self.assertIsInstance(result[0], dict)
         self.assertIn(1, result[0])
-        self.assertEqual(result[0][1].max, self.max_seq_len - 1)
+        self.assertEqual(result[0][1].max, self.max_seq_len)
 
         # Check second element (input_pos dimension)
         self.assertIsInstance(result[1], dict)
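Why <= is the right bound: an update of seq_len tokens starting at start_pos writes positions start_pos through start_pos + seq_len - 1, so the exclusive end index start_pos + seq_len may legitimately equal the cache's sequence length. With that check relaxed, token_dim no longer needs the max_seq_len - 1 cap. A minimal sketch of the boundary case follows; the 4-dimensional cache shape is assumed purely for illustration, the diff itself only relies on dim 1 being the sequence dimension.

# Illustrative sketch, not part of the patch.
# Assumed cache layout: (batch, max_seq_len, n_heads, head_dim).
import torch

max_seq_len = 8
cache = torch.zeros(1, max_seq_len, 4, 16)
start_pos, seq_len = 6, 2  # writes positions 6 and 7, filling the cache exactly

# The old strict check rejected this valid update: 6 + 2 < 8 is False.
print((start_pos + seq_len) < cache.size(1))    # False

# The relaxed check accepts it: the slice cache[:, start_pos:start_pos + seq_len]
# ends exactly at the cache boundary without overrunning it.
print((start_pos + seq_len) <= cache.size(1))   # True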