Skip to content

Commit a897dae

Browse files
larryliu0820facebook-github-bot
authored andcommitted
See what happens if we export with max_seq_len (#11611)
Summary: See if any CI is broken by this. Differential Revision: D76530379 Pulled By: larryliu0820
1 parent df4c12e commit a897dae

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

extension/llm/export/builder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ def _get_dynamic_shape(self) -> Any:
180180
if self.dynamic_shapes:
181181
return self.dynamic_shapes
182182

183-
dim = torch.export.Dim("token_dim", max=self.max_seq_len - 1)
183+
dim = torch.export.Dim("token_dim", max=self.max_seq_len)
184184
if self.enable_dynamic_shape:
185185
if not self.use_kv_cache:
186186
# Only one input argument: tokens

extension/llm/export/test/test_builder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def test_get_dynamic_shape_with_dynamic_shape_enabled_no_kv_cache(self) -> None:
6363
self.assertIsInstance(result[0], dict)
6464
self.assertIn(1, result[0])
6565
# Check that the value at key 1 is a torch.export.Dim with the correct max value
66-
self.assertEqual(result[0][1].max, self.max_seq_len - 1)
66+
self.assertEqual(result[0][1].max, self.max_seq_len)
6767

6868
def test_get_dynamic_shape_with_dynamic_shape_enabled_with_kv_cache(self) -> None:
6969
"""Test _get_dynamic_shape when enable_dynamic_shape=True and use_kv_cache=True."""
@@ -88,7 +88,7 @@ def test_get_dynamic_shape_with_dynamic_shape_enabled_with_kv_cache(self) -> Non
8888
# Check first element (tokens dimension)
8989
self.assertIsInstance(result[0], dict)
9090
self.assertIn(1, result[0])
91-
self.assertEqual(result[0][1].max, self.max_seq_len - 1)
91+
self.assertEqual(result[0][1].max, self.max_seq_len)
9292

9393
# Check second element (input_pos dimension)
9494
self.assertIsInstance(result[1], dict)

0 commit comments

Comments
 (0)