Commit 85918c5

Don't include extra left padding in all_input_ids_tensor
It's unnecessary here
1 parent 0ebc567 commit 85918c5

3 files changed: 10 additions and 14 deletions
server/text_generation_server/models/causal_lm.py

Lines changed: 10 additions & 12 deletions
@@ -131,8 +131,10 @@ def from_pb(
         # Tokenize batch
         tokenize_length = max_input_length
         # Pad to multiple of 8 for tensor core GPUs
+        left_pad = 0
         if device.type == "cuda" and CUDA_PAD_TO_MULT_OF_8 and (mod := tokenize_length % 8) != 0:
-            tokenize_length += 8 - mod
+            left_pad = 8 - mod
+            tokenize_length += left_pad
         tokenized_inputs = tokenizer(
             input_texts,
             return_tensors="pt",
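
For reference, a minimal standalone sketch of the rounding arithmetic in this hunk. The helper function and the example lengths are invented for illustration; only the multiple-of-8 / `left_pad` logic mirrors the diff:

```python
def pad_to_multiple_of_8(max_input_length: int) -> tuple[int, int]:
    # Round the tokenization length up to a multiple of 8 (tensor-core friendly)
    # and report how many extra left-padding positions that rounding adds.
    left_pad = 0
    if (mod := max_input_length % 8) != 0:
        left_pad = 8 - mod
    return max_input_length + left_pad, left_pad

# e.g. a longest prompt of 13 tokens is tokenized into 16 columns, 3 of them left padding
assert pad_to_multiple_of_8(13) == (16, 3)
assert pad_to_multiple_of_8(16) == (16, 0)
```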
@@ -163,10 +165,11 @@ def from_pb(
         # Padded all_input_ids_tensor; the maximum length of any sequence is the max
         # (padded) input sequence length + the max output length
         all_input_ids_tensor = all_input_ids.new_full(
-            (batch_size, tokenize_length + padding_right_offset),
+            (batch_size, max_input_length + padding_right_offset),
             tokenizer.pad_token_id,
         )
-        all_input_ids_tensor[:, :all_input_ids.shape[1]] = all_input_ids
+        no_pad_input_ids = all_input_ids[:, left_pad:] if left_pad else all_input_ids
+        all_input_ids_tensor[:, :no_pad_input_ids.shape[1]] = no_pad_input_ids

         if prefix_ids:
             # Get input embeddings
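
A hedged sketch of the allocation change above: `all_input_ids_tensor` is now sized from the unpadded `max_input_length`, and the tokenizer's extra left padding is sliced off before the copy. The shapes and token ids below are made up; only the allocation and slicing pattern follows the diff:

```python
import torch

batch_size, max_input_length, left_pad, padding_right_offset = 2, 13, 3, 4
pad_token_id = 0

# Tokenizer output: left-padded to a multiple of 8 (3 pad columns + 13 real columns).
real_ids = torch.arange(1, max_input_length + 1).repeat(batch_size, 1)
all_input_ids = torch.cat(
    [torch.full((batch_size, left_pad), pad_token_id), real_ids], dim=1
)

# Allocate only real input length + output headroom, i.e. without the left padding.
all_input_ids_tensor = all_input_ids.new_full(
    (batch_size, max_input_length + padding_right_offset), pad_token_id
)
no_pad_input_ids = all_input_ids[:, left_pad:] if left_pad else all_input_ids
all_input_ids_tensor[:, :no_pad_input_ids.shape[1]] = no_pad_input_ids

assert all_input_ids_tensor.shape == (2, 17)  # 13 + 4 columns, not 16 + 4 as before
```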
@@ -282,22 +285,19 @@ def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
            )

            # We need to slice the attention mask and all_input_ids_tensor to
-           # remove padding from previous steps and to remove unused allocated
-           # space
+           # remove padding from previous steps and to remove unused allocated space
            left_offset = max_sequence_length - batch.max_sequence_length
-           batch_left_offset = (
-               batch.attention_mask.shape[1] - batch.max_sequence_length - batch.padding_right_offset
-           )
+           batch_left_offset = -batch.max_sequence_length - batch.padding_right_offset
            attention_mask[
                start_index:end_index, left_offset:-padding_right_offset,
            ] = batch.attention_mask[
-               :, batch_left_offset : -batch.padding_right_offset,
+               :, batch_left_offset: -batch.padding_right_offset,
            ]

            all_input_ids_tensor[
                start_index:end_index, left_offset:-padding_right_offset,
            ] = batch.all_input_ids_tensor[
-               :, batch_left_offset : -batch.padding_right_offset,
+               :, :-batch.padding_right_offset,
            ]

            if batch.position_ids is not None:
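
The `batch_left_offset` rewrite is an equivalence, not a behavior change: a negative index counted from the right selects the same columns as the old width-based arithmetic. And because `all_input_ids_tensor` no longer carries left padding, its slice only needs to drop the right-hand padding. A small check with invented shapes:

```python
import torch

mask_width, batch_max_seq_len, padding_right_offset = 24, 9, 4
attention_mask = torch.arange(mask_width).unsqueeze(0)

old_offset = mask_width - batch_max_seq_len - padding_right_offset  # 11
new_offset = -batch_max_seq_len - padding_right_offset              # -13

# Both forms pick the same 9 "live" columns just left of the right padding.
assert torch.equal(
    attention_mask[:, old_offset:-padding_right_offset],
    attention_mask[:, new_offset:-padding_right_offset],
)
```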
@@ -692,7 +692,6 @@ def generate_token(

            except Exception as e:
                logging.exception(f"token decoding error for request #{request.id}")
-               next_token = all_input_ids.new_tensor([self.tokenizer.pad_token_id])
                # Add to the errors to return
                decode_errors.append(GenerateError(
                    request_id=request.id, message=f"Token decoding error: {str(e)}"
@@ -713,7 +712,6 @@ def generate_token(
            # Trim pre-allocated tensors if we padded to multiple of 8. This
            # is important to be able to generate up to the model's token limit.
            batch.attention_mask = batch.attention_mask[:, left_pad:]
-           batch.all_input_ids_tensor = batch.all_input_ids_tensor[:, left_pad:]
            # For a combined KV cache, past is a list of Tensors, not Tuples
            if torch.is_tensor(past[0]):
                for cache in past:
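
With the allocation change, only the attention mask still contains the multiple-of-8 left padding, so only it is trimmed here; `all_input_ids_tensor` never grew those columns. A toy illustration with invented shapes:

```python
import torch

left_pad, max_len = 3, 13
attention_mask = torch.ones(1, left_pad + max_len)  # padded to 16 columns
all_input_ids_tensor = torch.zeros(1, max_len)      # allocated without left padding

attention_mask = attention_mask[:, left_pad:]       # trim is still required here
# No trim is needed for all_input_ids_tensor; its width never included left_pad.
assert attention_mask.shape[1] == all_input_ids_tensor.shape[1] == max_len
```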

server/text_generation_server/models/flash_causal_lm.py

Lines changed: 0 additions & 1 deletion
@@ -577,7 +577,6 @@ def _process_new_tokens(

            except Exception as e:
                logging.exception(f"token decoding error for request #{request.id}")
-               next_token = all_input_ids.new_tensor([self.tokenizer.pad_token_id])
                # Add to the errors to return
                decode_errors.append(GenerateError(
                    request_id=request.id, message=f"Token decoding error: {str(e)}"

server/text_generation_server/models/seq2seq_lm.py

Lines changed: 0 additions & 1 deletion
@@ -701,7 +701,6 @@ def generate_token(

            except Exception as e:
                logging.exception(f"token decoding error for request #{request.id}")
-               next_token = batch.all_decoder_input_ids_tensor.new_tensor([self.tokenizer.pad_token_id])
                # Add to the errors to return
                decode_errors.append(GenerateError(
                    request_id=request.id, message=f"Token decoding error: {str(e)}"
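
All three files also drop the same dead line: on a token-decoding failure the handler no longer synthesizes a pad-token `next_token` before recording the error. Below is a schematic, self-contained sketch of the remaining error path; `GenerateError` here is a stand-in dataclass, the `decode_token` helper is hypothetical, and the assumption that the substituted token was never consumed downstream is inferred from the removal itself, not stated in the commit message:

```python
import logging
from dataclasses import dataclass

@dataclass
class GenerateError:  # stand-in for the server's GenerateError message type
    request_id: int
    message: str

def decode_token(decode_fn, token_id, request_id, decode_errors):
    """Decode one token id; on failure record an error instead of faking a pad token."""
    try:
        return decode_fn(token_id)
    except Exception as e:
        logging.exception(f"token decoding error for request #{request_id}")
        # Add to the errors to return; no substitute next_token is produced.
        decode_errors.append(GenerateError(
            request_id=request_id, message=f"Token decoding error: {str(e)}"
        ))
        return None

def broken_decode(token_id):
    raise ValueError("unknown token id")

errors = []
assert decode_token(broken_decode, 42, request_id=7, decode_errors=errors) is None
assert len(errors) == 1
```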
