
Commit 1ddd912

Fix max seq length bug (#15141)
Summary: The update cache op was forcing an incorrect constraint on the sequence length. Fixing that, along with fixing export, allows us to correctly export the model.

Differential Revision: D84562463

Co-authored-by: Kimish Patel <[email protected]>
1 parent e8660d0 commit 1ddd912

2 files changed (+4, -3 lines)


extension/llm/custom_ops/custom_ops.py

Lines changed: 2 additions & 2 deletions
@@ -207,8 +207,8 @@ def _validate_update_cache_params(
         1
     ), f"Start position {start_pos} must be less than sequence length {cache.size(1)}"
 
-    torch._check((start_pos + seq_len) < cache.size(1))
-    assert (start_pos + seq_len) < cache.size(
+    torch._check((start_pos + seq_len) <= cache.size(1))
+    assert (start_pos + seq_len) <= cache.size(
         1
     ), f"Start position + length = {start_pos + seq_len} must be less than sequence length {cache.size(1)}"

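The boundary case is the heart of this fix: writing seq_len new entries starting at start_pos into a cache whose sequence dimension has length L is in bounds exactly when start_pos + seq_len <= L, so the old strict < rejected valid updates that fill the cache to its last slot. A minimal sketch of that case (the shapes and tensor names below are illustrative, not the actual custom op):

import torch

# Illustrative KV cache with room for 8 positions along dim 1: (batch, max_seq_len, head_dim).
cache = torch.zeros(1, 8, 4)
new_entries = torch.ones(1, 3, 4)  # 3 new entries to write

start_pos = 5
seq_len = new_entries.size(1)

# Filling positions 5, 6, 7 uses the cache exactly to its end:
# start_pos + seq_len == cache.size(1) == 8. The old strict `<` check
# rejected this update even though the slice below is in bounds.
assert start_pos + seq_len <= cache.size(1)
cache[:, start_pos : start_pos + seq_len, :] = new_entries
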
extension/llm/export/builder.py

Lines changed: 2 additions & 1 deletion
@@ -142,7 +142,8 @@ def __init__(
                 {1: torch.export.Dim("token_dim", max=self.max_seq_len - 1)},
             )
         else:
-            # Two input arguments: tokens and input_pos but input_pos is static shape
+            # Two input arguments: tokens and input_pos but input_pos is static shape.
+            # [second added comment line; its text was not captured in this diff view]
             self.dynamic_shapes = (
                 {1: torch.export.Dim("token_dim", max=self.max_seq_len)},
                 {"input_pos": {0: 1}},
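For reference, the dynamic_shapes tuple above is what torch.export consumes: dim 1 of tokens is dynamic with max_seq_len as its upper bound, while input_pos keeps a static shape. A small self-contained sketch of that export pattern (TinyModel, max_seq_len, and the example inputs are stand-ins, not the ExecuTorch builder itself):

import torch

class TinyModel(torch.nn.Module):
    def forward(self, tokens: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
        # Stand-in for a decode step: consumes a token window plus a position index.
        return tokens.float().sum(dim=1, keepdim=True) + input_pos.float()

max_seq_len = 128

# Dim 1 of `tokens` is dynamic, bounded by max_seq_len; `input_pos` stays fully static.
dynamic_shapes = (
    {1: torch.export.Dim("token_dim", max=max_seq_len)},
    None,
)

example_inputs = (
    torch.zeros(1, 4, dtype=torch.long),  # tokens: (batch, seq)
    torch.zeros(1, dtype=torch.long),     # input_pos
)
exported = torch.export.export(TinyModel(), example_inputs, dynamic_shapes=dynamic_shapes)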

Comments (0)