from text_generation_server.prompt_cache import PrefixCache
from text_generation_server.utils.hub import get_model_path
from text_generation_server.utils.token_types import TokenInfo, InputTokens
-from text_generation_server.utils.tokens import NextTokenChooser, get_token_info, NONES
+from text_generation_server.utils.tokens import HeterogeneousNextTokenChooser, get_token_info, NONES
from text_generation_server.inference_engine import get_inference_engine_class
@@ -37,7 +37,7 @@ class Seq2SeqLMBatch(Batch):
    encoder_last_hidden_state: Optional[torch.Tensor]

    # All tokens
-    all_decoder_input_ids: List[torch.Tensor]
+    all_decoder_input_ids_tensor: torch.Tensor

    # Seq2SeqLM keeps track of both encoder and decoder attention keys and values
    past_key_values: Optional[List[Tuple]]
@@ -47,13 +47,14 @@ class Seq2SeqLMBatch(Batch):
    decoder_input_lengths: List[int]

    # Generation helpers
-    next_token_choosers: List[NextTokenChooser]
+    next_token_chooser: HeterogeneousNextTokenChooser

    # Metadata used for padding
    max_input_length: int
    max_decoder_input_length: int
    padding_right_offset: int
    max_remaining_tokens: List[int]
+    pad_token_id: int

    # Past metadata
    keys_head_dim_last: bool = True
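The two field changes above summarize the whole refactor: the per-request NextTokenChooser list becomes a single batched HeterogeneousNextTokenChooser, and the per-sequence list of growing token tensors becomes one pre-allocated, right-padded tensor that is written in place. A minimal sketch of the data-structure change, with illustrative values rather than anything taken from the actual class:

```python
import torch

# Before: one growing tensor per sequence, re-concatenated on every step.
all_decoder_input_ids = [torch.tensor([1]), torch.tensor([1])]
next_tokens = [torch.tensor([5]), torch.tensor([7])]
all_decoder_input_ids = [torch.cat([ids, tok]) for ids, tok in zip(all_decoder_input_ids, next_tokens)]

# After: a single [batch, max_len + padding_right_offset] tensor allocated once;
# each generation step is a single indexed write instead of a Python loop.
padding_right_offset = 3
all_decoder_input_ids_tensor = torch.full((2, 1 + padding_right_offset), 0, dtype=torch.int64)
all_decoder_input_ids_tensor[:, 0] = 1                                         # bos token
all_decoder_input_ids_tensor[:, -padding_right_offset] = torch.tensor([5, 7])  # new tokens for the whole batch
```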
@@ -77,7 +78,8 @@ def from_pb(
    ) -> Tuple[Optional["Seq2SeqLMBatch"], List[GenerateError]]:
        """Convert a text_generation_server.v1.Batch protobuf to a Seq2SeqLMBatch"""
        input_texts = []
-        next_token_choosers = []
+        next_token_chooser_parameters = []
+        return_logprobs = []
        input_lengths = []
        decoder_input_lengths = []
        max_remaining_tokens = []
@@ -123,9 +125,8 @@ def from_pb(
            max_input_length = max(max_input_length, input_length)
            max_remaining_tokens.append(max_output_length)
            padding_right_offset = max(padding_right_offset, max_output_length)
-            next_token_choosers.append(NextTokenChooser.from_pb(
-                r.parameters, r.details.logprobs, tokenizer, device,
-            ))
+            next_token_chooser_parameters.append(r.parameters)
+            return_logprobs.append(r.details.logprobs)
            i += 1

        if errors:
@@ -176,6 +177,12 @@ def from_pb(
        else:
            inputs_embeds = None

+        # Allocate maximal all_decoder_input_ids_tensor
+        all_decoder_input_ids_tensor = torch.full(
+            (batch_size, max_decoder_input_length + padding_right_offset),
+            tokenizer.pad_token_id,
+            dtype=torch.int64, device=device,
+        )
        if decoder_prefix_ids:
            # Construct decoder embeddings and attention mask
            start_tok_embedding = prefix_cache.decoder_start_tok_embedding
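The allocation above reserves room for every token the batch could ever hold: max_decoder_input_length columns for what already exists plus padding_right_offset columns for tokens yet to be generated, all initialized to pad_token_id. A sketch of the resulting layout, using made-up sizes rather than values from the code:

```python
import torch

batch_size = 2
max_decoder_input_length = 3   # longest decoder prefix in the batch
padding_right_offset = 4       # most tokens any request may still generate
pad_token_id, bos_token_id = 0, 1

all_decoder_input_ids_tensor = torch.full(
    (batch_size, max_decoder_input_length + padding_right_offset),
    pad_token_id, dtype=torch.int64,
)
# Sequences are right-aligned against the reserved columns: the latest real token
# sits at column -1 - padding_right_offset, and the rightmost padding_right_offset
# columns are filled in, one per generation step.
all_decoder_input_ids_tensor[:, -1 - padding_right_offset] = bos_token_id
print(all_decoder_input_ids_tensor)
# tensor([[0, 0, 1, 0, 0, 0, 0],
#         [0, 0, 1, 0, 0, 0, 0]])
```

In the bos-only branch below, every decoder sequence has length 1, so the tensor width is 1 + padding_right_offset and writing column 0 is the same right-aligned position.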
@@ -187,6 +194,7 @@ def from_pb(
                (batch_size, max_decoder_input_length + padding_right_offset)
            )
            decoder_attention_mask[:, -1 - padding_right_offset] = 1
+            all_decoder_input_ids_tensor[:, -1 - padding_right_offset] = tokenizer.bos_token_id

            for i, dp in decoder_prefix_ids.items():
                # Update decoder embedding and attention mask
@@ -203,6 +211,15 @@ def from_pb(
            # Each decoder sequence only contains the bos_token
            # so decoder_input_ids is a torch tensor of size [batch_size, 1]
            decoder_input_ids = input_ids.new_full((batch_size, 1), tokenizer.bos_token_id)
+            all_decoder_input_ids_tensor[:, 0] = tokenizer.bos_token_id
+
+        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
+            pb=next_token_chooser_parameters,
+            model_eos_token_id=getattr(tokenizer, 'model_eos_token_id', tokenizer.eos_token_id),
+            return_logprobs=return_logprobs,
+            dtype=dtype,
+            device=device
+        )

        return cls(
            batch_id=pb.id,
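Rather than calling one chooser per request inside a Python loop, the per-request sampling parameters are gathered into lists and handed to a single chooser that applies them as batched tensor operations. A conceptual sketch of what "heterogeneous" means here; this is illustrative only and is not the actual HeterogeneousNextTokenChooser implementation or API:

```python
import torch

def choose_next_tokens(scores, temperatures, do_sample):
    """scores: [batch, vocab] logits for the last position; one temperature / do_sample flag per request."""
    temps = scores.new_tensor(temperatures).unsqueeze(1)    # [batch, 1], one temperature per row
    scaled = scores / temps
    probs = torch.softmax(scaled, dim=-1)
    sampled = torch.multinomial(probs, num_samples=1).squeeze(1)
    greedy = scaled.argmax(dim=-1)
    use_sampling = torch.tensor(do_sample, dtype=torch.bool, device=scores.device)
    return torch.where(use_sampling, sampled, greedy)       # [batch] next token ids

# Example: greedy decoding for request 0, sampling at temperature 0.7 for request 1.
next_ids = choose_next_tokens(torch.randn(2, 10), [1.0, 0.7], [False, True])
```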
@@ -213,16 +230,17 @@ def from_pb(
            decoder_input_ids=decoder_input_ids,
            decoder_inputs_embeds=decoder_inputs_embeds,
            decoder_attention_mask=decoder_attention_mask,
-            all_decoder_input_ids=list(decoder_input_ids),
+            all_decoder_input_ids_tensor=all_decoder_input_ids_tensor,
            encoder_last_hidden_state=None,
            past_key_values=None,
            input_lengths=input_lengths,
            decoder_input_lengths=decoder_input_lengths,
            max_remaining_tokens=max_remaining_tokens,
-            next_token_choosers=next_token_choosers,
+            next_token_chooser=next_token_chooser,
            max_input_length=max_input_length,
            max_decoder_input_length=max_decoder_input_length,
            padding_right_offset=padding_right_offset,
+            pad_token_id=tokenizer.pad_token_id,
        ), errors

    @classmethod
@@ -247,12 +265,15 @@ def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch":
        input_lengths = []
        decoder_input_lengths = []
        max_remaining_tokens = []
-        next_token_choosers = []
-        all_decoder_input_ids = []
+        next_token_chooser_parameters = []
+        ntc_current_tokens = []
+        ntc_samplings = []
+        ntc_return_logprobs = []

        # Batch tensors
        attention_mask = None
        decoder_input_ids = None
+        all_decoder_input_ids_tensor = None
        decoder_attention_mask = None
        encoder_last_hidden_state = None
        past_key_values = []
@@ -267,8 +288,11 @@ def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch":
            input_lengths.extend(batch.input_lengths)
            decoder_input_lengths.extend(batch.decoder_input_lengths)
            max_remaining_tokens.extend(batch.max_remaining_tokens)
-            next_token_choosers.extend(batch.next_token_choosers)
-            all_decoder_input_ids.extend(batch.all_decoder_input_ids)
+
+            next_token_chooser_parameters.extend(r.parameters for r in batch.requests)
+            ntc_current_tokens.extend(batch.next_token_chooser.current_tokens)
+            ntc_samplings.extend(batch.next_token_chooser.samplings)
+            ntc_return_logprobs.extend(batch.next_token_chooser.return_logprobs)

            # Slicing end index for this batch
            end_index = start_index + len(batch)
@@ -293,6 +317,20 @@ def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch":
            # Copy to correct indices
            decoder_input_ids[start_index:end_index] = batch.decoder_input_ids

+            # Create padded tensor
+            if all_decoder_input_ids_tensor is None:
+                all_decoder_input_ids_tensor = batches[0].all_decoder_input_ids_tensor.new_full(
+                    (total_batch_size, max_decoder_input_length + padding_right_offset),
+                    batches[0].pad_token_id,
+                )
+            # Copy to correct sub-tensor
+            rhs_pad_diff = padding_right_offset - batch.padding_right_offset
+            all_decoder_input_ids_tensor[
+                start_index:end_index,
+                -(batch.all_decoder_input_ids_tensor.shape[1] + rhs_pad_diff)
+                :(all_decoder_input_ids_tensor.shape[1] - rhs_pad_diff)
+            ] = batch.all_decoder_input_ids_tensor
+
            # Create padded tensor
            if decoder_attention_mask is None:
                # As decoder_attention_mask might not exist, we use `batch.attention_mask` for device here
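The destination slice above lines each incoming batch up with the merged tensor's larger right padding: rhs_pad_diff is how many extra reserved columns the merged tensor has, so the copy shifts the batch that many columns in from the right edge. A worked example with made-up sizes, not values taken from the code:

```python
# Merged tensor: max_decoder_input_length = 6, padding_right_offset = 4  -> width 10
# Incoming batch: its max length = 3, its padding_right_offset = 2       -> width 5
merged_width, batch_width = 6 + 4, 3 + 2
rhs_pad_diff = 4 - 2

dst_start = merged_width - (batch_width + rhs_pad_diff)   # == -(batch_width + rhs_pad_diff) as an index
dst_end = merged_width - rhs_pad_diff
assert (dst_start, dst_end) == (3, 8)                     # a 5-column window, matching the source width

# The batch's most recent token (its column -1 - 2, i.e. index 2) lands at merged
# column dst_start + 2 == 5, which is -1 - padding_right_offset in the merged tensor.
```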
@@ -391,6 +429,16 @@ def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch":
            start_index = end_index

+        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
+            pb=next_token_chooser_parameters,
+            model_eos_token_id=batches[0].next_token_chooser.eos_token_id,
+            return_logprobs=ntc_return_logprobs,
+            dtype=batches[0].next_token_chooser.dtype,
+            device=batches[0].next_token_chooser.device,
+            samplings=ntc_samplings,
+            current_tokens=ntc_current_tokens,
+        )
+
        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
@@ -401,15 +449,16 @@ def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch":
            decoder_inputs_embeds=None,
            decoder_attention_mask=decoder_attention_mask,
            encoder_last_hidden_state=encoder_last_hidden_state,
-            all_decoder_input_ids=all_decoder_input_ids,
+            all_decoder_input_ids_tensor=all_decoder_input_ids_tensor,
            past_key_values=past_key_values,
            input_lengths=input_lengths,
            decoder_input_lengths=decoder_input_lengths,
            max_remaining_tokens=max_remaining_tokens,
-            next_token_choosers=next_token_choosers,
+            next_token_chooser=next_token_chooser,
            max_input_length=max_input_length,
            max_decoder_input_length=max_decoder_input_length,
            padding_right_offset=padding_right_offset,
+            pad_token_id=batches[0].pad_token_id,
            keys_head_dim_last=batches[0].keys_head_dim_last,
        )
@@ -435,8 +484,7 @@ def prune(cls, batch: "Seq2SeqLMBatch", completed_ids: List[int]) -> Optional["S
        batch.decoder_input_lengths = list(slice_list(batch.decoder_input_lengths))
        batch.max_remaining_tokens = list(slice_list(batch.max_remaining_tokens))
        batch.requests = slice_list(batch.requests)
-        batch.next_token_choosers = slice_list(batch.next_token_choosers)
-        batch.all_decoder_input_ids = list(slice_list(batch.all_decoder_input_ids))
+        batch.next_token_chooser = batch.next_token_chooser.filter(keep_indices)

        batch.max_input_length = max(batch.input_lengths)
        batch.max_decoder_input_length = max(batch.decoder_input_lengths)
@@ -456,6 +504,12 @@ def prune(cls, batch: "Seq2SeqLMBatch", completed_ids: List[int]) -> Optional["S
        batch.attention_mask = batch.attention_mask[keep_indices, -batch.max_input_length:]
        batch.decoder_input_ids = batch.decoder_input_ids[keep_indices]

+        batch.all_decoder_input_ids_tensor = batch.all_decoder_input_ids_tensor[
+            keep_indices,
+            -(batch.padding_right_offset + batch.max_decoder_input_length)
+            :(batch.all_decoder_input_ids_tensor.shape[1] - batch.padding_right_offset) + new_padding_right_offset,
+        ]
+
        # Ensure that past_key_values tensors can be updated in-place
        if type(batch.past_key_values[0]) == tuple:
            batch.past_key_values = [list(layer) for layer in batch.past_key_values]
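The column slice keeps only what the surviving requests still need: batch.max_decoder_input_length has just been recomputed over the remaining sequences, and new_padding_right_offset is presumably the reduced reservation for tokens still to come (it is computed outside this hunk). A worked example with made-up sizes:

```python
# Before pruning: max_decoder_input_length = 6, padding_right_offset = 4 -> width 10
# After pruning:  remaining max length = 4, remaining right offset = 2
old_width, old_pro = 6 + 4, 4
new_max_dec_len, new_pro = 4, 2

left = old_width - (old_pro + new_max_dec_len)    # == -(old_pro + new_max_dec_len) as an index
right = (old_width - old_pro) + new_pro
assert (left, right) == (2, 8) and right - left == new_max_dec_len + new_pro

# The most recent token (old column -1 - old_pro, i.e. index 5) ends up at new
# column 3, which is again -1 - new_pro in the sliced tensor.
```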
@@ -589,8 +643,9 @@ def generate_token(
            batch.decoder_inputs_embeds,
        )

-        # New values for next forward
-        next_batch_decoder_input_ids = []
+        next_input_ids, next_token_scores, next_token_logprobs = batch.next_token_chooser(
+            input_ids=batch.all_decoder_input_ids_tensor, scores=logits[:, -1, :]
+        )

        # Generated tokens
        generated_tokens: List[TokenInfo] = []
@@ -601,8 +656,9 @@ def generate_token(
        iterator = zip(
            batch.requests,
            logits,
-            batch.next_token_choosers,
-            batch.all_decoder_input_ids,
+            next_input_ids,
+            next_token_scores,
+            next_token_logprobs,
            batch.decoder_input_lengths,
            batch.input_ids if first else NONES,
            batch.input_lengths if first else NONES,
@@ -612,20 +668,20 @@ def generate_token(
        for i, (
            request,
            logits,
-            next_token_chooser,
-            all_decoder_input_ids,
+            next_token,
+            scores,
+            logprobs,
            decoder_input_length,
            input_ids,
            input_length,
        ) in enumerate(iterator):
            try:
-                # Select next token
-                next_token, scores, logprobs = next_token_chooser(
-                    all_decoder_input_ids[None, -decoder_input_length:], logits
-                )
+                # Ensure the token view is 1-D and the scores/logprobs views are 2-D
+                tok_view = next_token.view(-1)
+                scores_view = scores.view(-1, scores.shape[-1])
+                logprobs_view = logprobs.view(-1, logprobs.shape[-1]) if request.details.logprobs else None

-                # Return latest token
-                token_info = get_token_info(request, scores, next_token, logprobs)
+                token_info = get_token_info(request, scores_view, tok_view, logprobs_view)

                # Return input tokens if requested
                # Note this could all be handled in the router for seq2seq models
@@ -645,25 +701,26 @@ def generate_token(
            except Exception as e:
                logging.exception(f"token decoding error for request #{request.id}")
-                next_token = all_decoder_input_ids.new_tensor([self.tokenizer.pad_token_id])
+                next_token = batch.all_decoder_input_ids_tensor.new_tensor([self.tokenizer.pad_token_id])
                # Add to the errors to return
                decode_errors.append(GenerateError(
                    request_id=request.id, message=f"Token decoding error: {str(e)}"
                ))

-            # Append next token to decoder tokens
-            next_batch_decoder_input_ids.append(next_token)
-            batch.all_decoder_input_ids[i] = torch.cat([all_decoder_input_ids, next_token])
+            # Adjust input/output counters
            batch.decoder_input_lengths[i] += 1
            batch.max_remaining_tokens[i] -= 1

+        # Add generated tokens to the input_ids_tensor
+        batch.all_decoder_input_ids_tensor[:, -batch.padding_right_offset] = next_input_ids
+
        # Update decoder_attention_mask as we added a new token to input_ids
        if batch.decoder_attention_mask is not None:
            batch.decoder_attention_mask[:, -batch.padding_right_offset] = 1

        batch.input_ids = None
        batch.inputs_embeds = None
-        batch.decoder_input_ids = torch.cat(next_batch_decoder_input_ids).view(-1, 1)
+        batch.decoder_input_ids = next_input_ids.view(-1, 1)
        batch.decoder_inputs_embeds = None
        batch.encoder_last_hidden_state = encoder_last_hidden_state
        batch.past_key_values = past