File tree Expand file tree Collapse file tree 2 files changed +8
-3
lines changed Expand file tree Collapse file tree 2 files changed +8
-3
lines changed Original file line number Diff line number Diff line change @@ -142,9 +142,14 @@ def __init__(
142
142
{1 : torch .export .Dim ("token_dim" , max = self .max_seq_len - 1 )},
143
143
)
144
144
else :
145
- # Two input arguments: tokens and input_pos but input_pos is static shape
145
+ # Two input arguments: tokens and input_pos but input_pos is static shape.
146
+
147
+ # A runtime assertion added by torch.ops.llama.update_cache requires that
148
+ # L['tokens'].size()[1] + input_pos[0].item() < self.max_seq_len
149
+ # This constrains L['tokens'].size()[1] to be at most self.max_seq_len - 1.
150
+ # Run with TORCH_LOGS=+dynamic for details.
146
151
self .dynamic_shapes = (
147
- {1 : torch .export .Dim ("token_dim" , max = self .max_seq_len )},
152
+ {1 : torch .export .Dim ("token_dim" , max = self .max_seq_len - 1 )},
148
153
{"input_pos" : {0 : 1 }},
149
154
)
150
155
Original file line number Diff line number Diff line change @@ -88,7 +88,7 @@ def test_get_dynamic_shape_with_dynamic_shape_enabled_with_kv_cache(self) -> Non
88
88
# Check first element (tokens dimension)
89
89
self .assertIsInstance (result [0 ], dict )
90
90
self .assertIn (1 , result [0 ])
91
- self .assertEqual (result [0 ][1 ].max , self .max_seq_len )
91
+ self .assertEqual (result [0 ][1 ].max , self .max_seq_len - 1 )
92
92
93
93
# Check second element (input_pos dimension)
94
94
self .assertIsInstance (result [1 ], dict )
You can’t perform that action at this time.
0 commit comments