
Commit 2e2ad98

kimishpatel authored and facebook-github-bot committed
Fix max seq length bug
Summary: The update cache op was forcing an incorrect constraint on the sequence length. Fixing that, along with fixing the export, allows us to correctly export the model.

Differential Revision: D84562463
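The off-by-one can be seen with a small worked example (the max_seq_len value of 128 is an assumption for illustration): prefilling a full-length prompt writes cache positions 0 through max_seq_len - 1, so start_pos + seq_len equals max_seq_len exactly. The old strict `<` check rejected that legal write, which in turn forced the exported token dimension to be capped at max_seq_len - 1.

```python
# Worked boundary case (assumed values: max_seq_len = 128, full-length prefill from position 0).
max_seq_len = 128
start_pos, seq_len = 0, 128  # writes cache positions 0 .. 127, all in bounds

old_check = (start_pos + seq_len) < max_seq_len    # False: a legal write is rejected
new_check = (start_pos + seq_len) <= max_seq_len   # True: the write is allowed

print(old_check, new_check)
```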
1 parent 19c9ff3 commit 2e2ad98

File tree

2 files changed: 3 additions, 7 deletions


extension/llm/custom_ops/custom_ops.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -207,8 +207,8 @@ def _validate_update_cache_params(
         1
     ), f"Start position {start_pos} must be less than sequence length {cache.size(1)}"

-    torch._check((start_pos + seq_len) < cache.size(1))
-    assert (start_pos + seq_len) < cache.size(
+    torch._check((start_pos + seq_len) <= cache.size(1))
+    assert (start_pos + seq_len) <= cache.size(
         1
     ), f"Start position + length = {start_pos + seq_len} must be less than sequence length {cache.size(1)}"
```
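A minimal sketch of the corrected bounds check, assuming a simplified signature (the real `_validate_update_cache_params` takes additional arguments): a write of `seq_len` entries starting at `start_pos` touches cache rows `start_pos` through `start_pos + seq_len - 1`, so the inclusive comparison against `cache.size(1)` is the correct constraint.

```python
import torch


def _check_cache_write_bounds(cache: torch.Tensor, start_pos: int, seq_len: int) -> None:
    """Simplified stand-in for the bounds check inside _validate_update_cache_params."""
    # torch._check records the constraint for the tracer/compiler as well as checking it eagerly.
    torch._check((start_pos + seq_len) <= cache.size(1))
    assert (start_pos + seq_len) <= cache.size(1), (
        f"Start position + length = {start_pos + seq_len} must not exceed "
        f"sequence length {cache.size(1)}"
    )


# Writing the final slot of a 16-entry cache now passes with the inclusive bound.
cache = torch.zeros(1, 16, 4)
_check_cache_write_bounds(cache, start_pos=15, seq_len=1)
```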

extension/llm/export/builder.py

Lines changed: 1 addition & 5 deletions
```diff
@@ -144,12 +144,8 @@ def __init__(
         else:
             # Two input arguments: tokens and input_pos but input_pos is static shape.

-            # A runtime assertion is added by torch.ops.llama.update_cache requires that
-            # L['tokens'].size()[1] + input_pos[0].item() < self.max_seq_len
-            # This consttaint L['tokens'].size()[1] to be elf.max_seq_len-1
-            # run with TORCH_LOGS=+dynamic for details
             self.dynamic_shapes = (
-                {1: torch.export.Dim("token_dim", max=self.max_seq_len - 1)},
+                {1: torch.export.Dim("token_dim", max=self.max_seq_len)},
                 {"input_pos": {0: 1}},
             )
```
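To show how the relaxed bound is consumed by export, here is a hedged sketch of a `torch.export.export` call with a token dimension that is dynamic up to `max_seq_len`. The tiny module, the `max_seq_len` value, and the dict-form `dynamic_shapes` spec are illustration-only assumptions, not the actual ExecuTorch builder code.

```python
import torch

max_seq_len = 32  # assumed value for illustration


class TinyModel(torch.nn.Module):
    def forward(self, tokens: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
        # Stand-in computation; a real LLM would index a KV cache with input_pos.
        return tokens.float().sum(dim=1, keepdim=True) + input_pos.float()


# The token dimension may now range up to max_seq_len (previously max_seq_len - 1);
# input_pos stays fully static, mirroring the intent of the builder's dynamic_shapes.
dynamic_shapes = {
    "tokens": {1: torch.export.Dim("token_dim", max=max_seq_len)},
    "input_pos": None,
}

ep = torch.export.export(
    TinyModel(),
    (torch.zeros(1, max_seq_len, dtype=torch.long), torch.tensor([0])),
    dynamic_shapes=dynamic_shapes,
)
print(ep)
```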
