@@ -131,8 +131,10 @@ def from_pb(
         # Tokenize batch
         tokenize_length = max_input_length
         # Pad to multiple of 8 for tensor core GPUs
+        left_pad = 0
         if device.type == "cuda" and CUDA_PAD_TO_MULT_OF_8 and (mod := tokenize_length % 8) != 0:
-            tokenize_length += 8 - mod
+            left_pad = 8 - mod
+            tokenize_length += left_pad
         tokenized_inputs = tokenizer(
             input_texts,
             return_tensors="pt",
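Note (illustration, not part of the diff): a minimal sketch of the left-padding arithmetic above, using an assumed max input length of 13 tokens; the CUDA_PAD_TO_MULT_OF_8 flag and device check are taken as given from the surrounding file.

    # Illustrative only: how left_pad and tokenize_length relate.
    max_input_length = 13                      # example value: longest raw input in the batch
    left_pad = 0
    if (mod := max_input_length % 8) != 0:     # mirrors the CUDA_PAD_TO_MULT_OF_8 branch
        left_pad = 8 - mod                     # 8 - 5 = 3
    tokenize_length = max_input_length + left_pad  # 16, a multiple of 8
    assert tokenize_length % 8 == 0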
@@ -163,10 +165,11 @@ def from_pb(
         # Padded all_input_ids_tensor; the maximum length of any sequence is the max
         # (padded) input sequence length + the max output length
         all_input_ids_tensor = all_input_ids.new_full(
-            (batch_size, tokenize_length + padding_right_offset),
+            (batch_size, max_input_length + padding_right_offset),
             tokenizer.pad_token_id,
         )
-        all_input_ids_tensor[:, :all_input_ids.shape[1]] = all_input_ids
+        no_pad_input_ids = all_input_ids[:, left_pad:] if left_pad else all_input_ids
+        all_input_ids_tensor[:, :no_pad_input_ids.shape[1]] = no_pad_input_ids

         if prefix_ids:
             # Get input embeddings
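Note (illustration, not part of the diff): a rough sketch of why the left padding is stripped before the copy; shapes and values are made up for the example, and the buffer is now sized from max_input_length rather than the padded tokenize_length.

    import torch

    batch_size, max_input_length, left_pad, padding_right_offset = 2, 13, 3, 4
    pad_token_id = 0

    # tokenizer output is left-padded to a multiple of 8 (13 + 3 = 16 columns)
    all_input_ids = torch.ones(batch_size, max_input_length + left_pad, dtype=torch.long)

    # pre-allocated buffer excludes the tensor-core padding (13 + 4 = 17 columns)
    all_input_ids_tensor = all_input_ids.new_full(
        (batch_size, max_input_length + padding_right_offset), pad_token_id
    )
    no_pad_input_ids = all_input_ids[:, left_pad:] if left_pad else all_input_ids
    all_input_ids_tensor[:, :no_pad_input_ids.shape[1]] = no_pad_input_ids  # 13 columns copied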
@@ -282,22 +285,19 @@ def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
             )

             # We need to slice the attention mask and all_input_ids_tensor to
-            # remove padding from previous steps and to remove unused allocated
-            # space
+            # remove padding from previous steps and to remove unused allocated space
             left_offset = max_sequence_length - batch.max_sequence_length
-            batch_left_offset = (
-                batch.attention_mask.shape[1] - batch.max_sequence_length - batch.padding_right_offset
-            )
+            batch_left_offset = -batch.max_sequence_length - batch.padding_right_offset
             attention_mask[
                 start_index:end_index, left_offset:-padding_right_offset,
             ] = batch.attention_mask[
-                :, batch_left_offset:-batch.padding_right_offset,
+                :, batch_left_offset:-batch.padding_right_offset,
             ]

             all_input_ids_tensor[
                 start_index:end_index, left_offset:-padding_right_offset,
             ] = batch.all_input_ids_tensor[
-                :, batch_left_offset:-batch.padding_right_offset,
+                :, :-batch.padding_right_offset,
             ]

             if batch.position_ids is not None:
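Note (illustration, not part of the diff): the new batch_left_offset is the old positive index rewritten as a negative one, so the attention-mask slice selects the same columns; a small self-contained check with made-up sizes:

    import torch

    t = torch.arange(20).reshape(2, 10)        # stand-in for batch.attention_mask
    max_sequence_length, padding_right_offset = 6, 2

    positive = t.shape[1] - max_sequence_length - padding_right_offset   # old form: 2
    negative = -max_sequence_length - padding_right_offset               # new form: -8
    assert torch.equal(t[:, positive:-padding_right_offset],
                       t[:, negative:-padding_right_offset])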
@@ -692,7 +692,6 @@ def generate_token(

             except Exception as e:
                 logging.exception(f"token decoding error for request #{request.id}")
-                next_token = all_input_ids.new_tensor([self.tokenizer.pad_token_id])
                 # Add to the errors to return
                 decode_errors.append(GenerateError(
                     request_id=request.id, message=f"Token decoding error: {str(e)}"
@@ -713,7 +712,6 @@ def generate_token(
         # Trim pre-allocated tensors if we padded to multiple of 8. This
         # is important to be able to generate up to the model's token limit.
         batch.attention_mask = batch.attention_mask[:, left_pad:]
-        batch.all_input_ids_tensor = batch.all_input_ids_tensor[:, left_pad:]
         # For a combined KV cache, past is a list of Tensors, not Tuples
         if torch.is_tensor(past[0]):
             for cache in past: