server/text_generation_server/models (1 file changed: +4 −3)

@@ -162,7 +162,7 @@ def from_pb(

         # Padded all_input_ids_tensor; the maximum length of any sequence is the max
         # (padded) input sequence length + the max output length
-        all_input_ids_tensor = torch.full(
+        all_input_ids_tensor = all_input_ids.new_full(
             (batch_size, tokenize_length + padding_right_offset),
             tokenizer.pad_token_id,
         )
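
For context: `Tensor.new_full` creates a tensor that inherits the dtype and device of the source tensor, while a bare `torch.full` defaults to CPU and the default dtype unless both are passed explicitly. A minimal sketch of the difference, outside the diff (the tensor names and shapes here are illustrative, not from the PR):

```python
import torch

# Stand-in for all_input_ids: an int64 tensor, possibly living on the GPU.
all_input_ids = torch.zeros(2, 5, dtype=torch.int64)

# Old style: dtype and device must be threaded through by hand, or they
# silently fall back to CPU / torch.get_default_dtype().
padded_old = torch.full(
    (2, 8), 0, dtype=all_input_ids.dtype, device=all_input_ids.device
)

# New style: new_full inherits dtype and device from the source tensor.
padded_new = all_input_ids.new_full((2, 8), 0)

assert padded_new.dtype == all_input_ids.dtype
assert padded_new.device == all_input_ids.device
```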
@@ -710,9 +710,10 @@ def generate_token(
         if first and not for_concat:
             left_pad = batch.attention_mask.shape[1] - batch.padding_right_offset - batch.max_sequence_length
             if left_pad:
-                # Trim attention mask and past kvs if we padded to multiple of 8. This is important to be able to
-                # generate up to the model's token limit.
+                # Trim pre-allocated tensors if we padded to multiple of 8. This
+                # is important to be able to generate up to the model's token limit.
                 batch.attention_mask = batch.attention_mask[:, left_pad:]
+                batch.all_input_ids_tensor = batch.all_input_ids_tensor[:, left_pad:]
                 # For a combined KV cache, past is a list of Tensors, not Tuples
                 if torch.is_tensor(past[0]):
                     for cache in past:
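
The added line is the substance of this hunk: `all_input_ids_tensor` is pre-allocated at the padded width, so if only the attention mask were trimmed, the same column index would refer to different token positions in the two tensors. A rough standalone sketch of the trim, with made-up sizes (`padding_right_offset`, `max_sequence_length`, and the shapes below are illustrative assumptions, not values from the PR):

```python
import torch

padding_right_offset = 2   # columns reserved for tokens yet to be generated
max_sequence_length = 5    # longest real (unpadded) input in the batch

# Input region left-padded up to a multiple of 8 (here 8), plus the right offset.
attention_mask = torch.ones(2, 10, dtype=torch.int64)
all_input_ids_tensor = torch.zeros(2, 10, dtype=torch.int64)

left_pad = attention_mask.shape[1] - padding_right_offset - max_sequence_length  # -> 3

if left_pad:
    # Trim both tensors together so column i means the same token position
    # in the mask and in the pre-allocated ids buffer.
    attention_mask = attention_mask[:, left_pad:]
    all_input_ids_tensor = all_input_ids_tensor[:, left_pad:]

assert attention_mask.shape == all_input_ids_tensor.shape  # kept in lockstep
```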