
Commit f3fc122

tjohnson31415 authored and njhill committed
fix: fixes after testing causal_lm vectorization on GPU
Signed-off-by: Travis Johnson <[email protected]>
1 parent 2756820 commit f3fc122

File tree: 1 file changed, +4 −3 lines


server/text_generation_server/models/causal_lm.py

Lines changed: 4 additions & 3 deletions
@@ -162,7 +162,7 @@ def from_pb(
 
         # Padded all_input_ids_tensor; the maximum length of any sequence is the max
         # (padded) input sequence length + the max output length
-        all_input_ids_tensor = torch.full(
+        all_input_ids_tensor = all_input_ids.new_full(
            (batch_size, tokenize_length + padding_right_offset),
            tokenizer.pad_token_id,
        )
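
Note on the first hunk: switching from torch.full to all_input_ids.new_full keeps the pre-allocated buffer on the same device (and with the same dtype) as the input ids, which matters once the batch lives on the GPU. A minimal sketch of the difference, independent of the repository code and assuming a CUDA device is available:

    import torch

    if torch.cuda.is_available():
        all_input_ids = torch.tensor([[1, 2, 3]], dtype=torch.int64, device="cuda")

        # torch.full with no explicit device argument allocates on the default
        # device (CPU), so later writes from GPU tensors would force copies or fail.
        cpu_buffer = torch.full((1, 8), 0, dtype=torch.int64)

        # Tensor.new_full inherits dtype and device from all_input_ids, so the
        # buffer is created directly on the GPU alongside the batch.
        gpu_buffer = all_input_ids.new_full((1, 8), 0)

        print(cpu_buffer.device)  # cpu
        print(gpu_buffer.device)  # cuda:0
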
@@ -710,9 +710,10 @@ def generate_token(
         if first and not for_concat:
             left_pad = batch.attention_mask.shape[1] - batch.padding_right_offset - batch.max_sequence_length
             if left_pad:
-                # Trim attention mask and past kvs if we padded to multiple of 8. This is important to be able to
-                # generate up to the model's token limit.
+                # Trim pre-allocated tensors if we padded to multiple of 8. This
+                # is important to be able to generate up to the model's token limit.
                 batch.attention_mask = batch.attention_mask[:, left_pad:]
+                batch.all_input_ids_tensor = batch.all_input_ids_tensor[:, left_pad:]
                 # For a combined KV cache, past is a list of Tensors, not Tuples
                 if torch.is_tensor(past[0]):
                     for cache in past:
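
Note on the second hunk: batch.all_input_ids_tensor is added to the tensors trimmed when the batch was left-padded to a multiple of 8, so the pre-allocated ids buffer stays aligned with the attention mask; per the in-code comment, this trim is what keeps generation able to reach the model's token limit. A hedged sketch of the trimming step, using stand-in arguments rather than the repository's CausalLMBatch object:

    def trim_left_padding(attention_mask, all_input_ids_tensor,
                          padding_right_offset, max_sequence_length):
        # Extra left padding introduced when the batch width was rounded up
        # (e.g. to a multiple of 8) beyond what the sequences actually need.
        left_pad = attention_mask.shape[1] - padding_right_offset - max_sequence_length
        if left_pad:
            # Trim both pre-allocated tensors together so their sequence
            # dimensions stay aligned for the rest of generation.
            attention_mask = attention_mask[:, left_pad:]
            all_input_ids_tensor = all_input_ids_tensor[:, left_pad:]
        return attention_mask, all_input_ids_tensor
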
