Skip to content

Commit deb42f2

Browse files
authored
Update llama export dynamic-shape (DS) specs to be more accurate.
Differential Revision: D83708583 Pull Request resolved: #14737
1 parent 0882c9b commit deb42f2

File tree

2 files changed

+8
-3
lines changed

2 files changed

+8
-3
lines changed

extension/llm/export/builder.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,9 +142,14 @@ def __init__(
142142
{1: torch.export.Dim("token_dim", max=self.max_seq_len - 1)},
143143
)
144144
else:
145-
# Two input arguments: tokens and input_pos but input_pos is static shape
145+
# Two input arguments: tokens and input_pos but input_pos is static shape.
146+
147+
# A runtime assertion added by torch.ops.llama.update_cache requires that
148+
# L['tokens'].size()[1] + input_pos[0].item() < self.max_seq_len
149+
# This constrains L['tokens'].size()[1] to be at most self.max_seq_len - 1.
150+
# Run with TORCH_LOGS=+dynamic for details.
146151
self.dynamic_shapes = (
147-
{1: torch.export.Dim("token_dim", max=self.max_seq_len)},
152+
{1: torch.export.Dim("token_dim", max=self.max_seq_len - 1)},
148153
{"input_pos": {0: 1}},
149154
)
150155

extension/llm/export/test/test_builder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def test_get_dynamic_shape_with_dynamic_shape_enabled_with_kv_cache(self) -> Non
8888
# Check first element (tokens dimension)
8989
self.assertIsInstance(result[0], dict)
9090
self.assertIn(1, result[0])
91-
self.assertEqual(result[0][1].max, self.max_seq_len)
91+
self.assertEqual(result[0][1].max, self.max_seq_len - 1)
9292

9393
# Check second element (input_pos dimension)
9494
self.assertIsInstance(result[1], dict)

0 commit comments

Comments
 (0)