File tree Expand file tree Collapse file tree 2 files changed +3
-3
lines changed Expand file tree Collapse file tree 2 files changed +3
-3
lines changed Original file line number Diff line number Diff line change @@ -180,7 +180,7 @@ def _get_dynamic_shape(self) -> Any:
180180 if self .dynamic_shapes :
181181 return self .dynamic_shapes
182182
183- dim = torch .export .Dim ("token_dim" , max = self .max_seq_len - 1 )
183+ dim = torch .export .Dim ("token_dim" , max = self .max_seq_len )
184184 if self .enable_dynamic_shape :
185185 if not self .use_kv_cache :
186186 # Only one input argument: tokens
Original file line number Diff line number Diff line change @@ -63,7 +63,7 @@ def test_get_dynamic_shape_with_dynamic_shape_enabled_no_kv_cache(self) -> None:
6363 self .assertIsInstance (result [0 ], dict )
6464 self .assertIn (1 , result [0 ])
6565 # Check that the value at key 1 is a torch.export.Dim with the correct max value
66- self .assertEqual (result [0 ][1 ].max , self .max_seq_len - 1 )
66+ self .assertEqual (result [0 ][1 ].max , self .max_seq_len )
6767
6868 def test_get_dynamic_shape_with_dynamic_shape_enabled_with_kv_cache (self ) -> None :
6969 """Test _get_dynamic_shape when enable_dynamic_shape=True and use_kv_cache=True."""
@@ -88,7 +88,7 @@ def test_get_dynamic_shape_with_dynamic_shape_enabled_with_kv_cache(self) -> Non
8888 # Check first element (tokens dimension)
8989 self .assertIsInstance (result [0 ], dict )
9090 self .assertIn (1 , result [0 ])
91- self .assertEqual (result [0 ][1 ].max , self .max_seq_len - 1 )
91+ self .assertEqual (result [0 ][1 ].max , self .max_seq_len )
9292
9393 # Check second element (input_pos dimension)
9494 self .assertIsInstance (result [1 ], dict )
You can’t perform that action at this time.
0 commit comments