
Commit 7ad221f

tjohnson31415 authored and njhill committed
feat: use HeterogeneousNextTokenChooser in seq2seq_lm
Signed-off-by: Travis Johnson <[email protected]>
1 parent f4060c0 commit 7ad221f

File tree

1 file changed: +91 −34 lines

server/text_generation_server/models/seq2seq_lm.py

Lines changed: 91 additions & 34 deletions
@@ -16,7 +16,7 @@
 from text_generation_server.prompt_cache import PrefixCache
 from text_generation_server.utils.hub import get_model_path
 from text_generation_server.utils.token_types import TokenInfo, InputTokens
-from text_generation_server.utils.tokens import NextTokenChooser, get_token_info, NONES
+from text_generation_server.utils.tokens import HeterogeneousNextTokenChooser, get_token_info, NONES
 from text_generation_server.inference_engine import get_inference_engine_class


@@ -37,7 +37,7 @@ class Seq2SeqLMBatch(Batch):
     encoder_last_hidden_state: Optional[torch.Tensor]

     # All tokens
-    all_decoder_input_ids: List[torch.Tensor]
+    all_decoder_input_ids_tensor: torch.Tensor

     # Seq2SeqLM keeps track of both encoder and decoder attention keys and values
     past_key_values: Optional[List[Tuple]]
@@ -47,13 +47,14 @@ class Seq2SeqLMBatch(Batch):
     decoder_input_lengths: List[int]

     # Generation helpers
-    next_token_choosers: List[NextTokenChooser]
+    next_token_chooser: HeterogeneousNextTokenChooser

     # Metadata used for padding
     max_input_length: int
     max_decoder_input_length: int
     padding_right_offset: int
     max_remaining_tokens: List[int]
+    pad_token_id: int

     # Past metadata
     keys_head_dim_last: bool = True
@@ -77,7 +78,8 @@ def from_pb(
     ) -> Tuple[Optional["Seq2SeqLMBatch"], List[GenerateError]]:
         """Convert a text_generation_server.v1.Batch protobuf to a Seq2SeqLMBatch"""
         input_texts = []
-        next_token_choosers = []
+        next_token_chooser_parameters = []
+        return_logprobs = []
         input_lengths = []
         decoder_input_lengths = []
         max_remaining_tokens = []
@@ -123,9 +125,8 @@ def from_pb(
             max_input_length = max(max_input_length, input_length)
             max_remaining_tokens.append(max_output_length)
             padding_right_offset = max(padding_right_offset, max_output_length)
-            next_token_choosers.append(NextTokenChooser.from_pb(
-                r.parameters, r.details.logprobs, tokenizer, device,
-            ))
+            next_token_chooser_parameters.append(r.parameters)
+            return_logprobs.append(r.details.logprobs)
             i += 1

         if errors:
@@ -176,6 +177,12 @@ def from_pb(
         else:
             inputs_embeds = None

+        # Allocate maximal decoder_all_input_ids_tensor
+        all_decoder_input_ids_tensor = torch.full(
+            (batch_size, max_decoder_input_length + padding_right_offset),
+            tokenizer.pad_token_id,
+            dtype=torch.int64, device=device,
+        )
         if decoder_prefix_ids:
             # Construct decoder embeddings and attention mask
             start_tok_embedding = prefix_cache.decoder_start_tok_embedding
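The allocation above reserves one row per request and max_decoder_input_length + padding_right_offset columns, pre-filled with the pad token, so every token the request may still generate already has a column waiting for it. A minimal standalone sketch of that layout, with made-up sizes and token ids (plain torch, not the repo's code):

import torch

batch_size, max_decoder_input_length, padding_right_offset = 2, 1, 3
pad_token_id, bos_token_id = 0, 2

# Pre-allocate the full buffer once: tokens so far plus room for every future token
buf = torch.full(
    (batch_size, max_decoder_input_length + padding_right_offset),
    pad_token_id, dtype=torch.int64,
)
# Seed the decoder start token in the last real (non-reserved) column
buf[:, -1 - padding_right_offset] = bos_token_id
print(buf)  # tensor([[2, 0, 0, 0],
            #         [2, 0, 0, 0]])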
@@ -187,6 +194,7 @@ def from_pb(
                 (batch_size, max_decoder_input_length + padding_right_offset)
             )
             decoder_attention_mask[:, -1-padding_right_offset] = 1
+            all_decoder_input_ids_tensor[:, -1-padding_right_offset] = tokenizer.bos_token_id

             for i, dp in decoder_prefix_ids.items():
                 # Update decoder embedding and attention mask
@@ -203,6 +211,15 @@ def from_pb(
             # Each decoder sequence only contains the bos_token
             # so decoder_input_ids is a torch tensor of size [batch_size, 1]
             decoder_input_ids = input_ids.new_full((batch_size, 1), tokenizer.bos_token_id)
+            all_decoder_input_ids_tensor[:, 0] = tokenizer.bos_token_id
+
+        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
+            pb=next_token_chooser_parameters,
+            model_eos_token_id=getattr(tokenizer, 'model_eos_token_id', tokenizer.eos_token_id),
+            return_logprobs=return_logprobs,
+            dtype=dtype,
+            device=device
+        )

         return cls(
             batch_id=pb.id,
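A single HeterogeneousNextTokenChooser is now built from every request's parameters instead of one NextTokenChooser per request. The sketch below is not the repo's implementation, only the underlying idea in plain torch: per-request settings (here a hypothetical temperature and sample/greedy flag) become tensors so one vectorized call picks next tokens for the whole batch:

import torch

logits = torch.randn(3, 32000)                    # last-step scores, one row per request
temperatures = torch.tensor([1.0, 0.7, 1.3])      # heterogeneous per-request parameters
do_sample = torch.tensor([False, True, True])

scaled = logits / temperatures.unsqueeze(1)       # one batched op instead of a Python loop
probs = torch.softmax(scaled, dim=-1)
sampled = torch.multinomial(probs, num_samples=1).squeeze(1)
greedy = scaled.argmax(dim=-1)
next_ids = torch.where(do_sample, sampled, greedy)  # shape [3]: one token id per request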
@@ -213,16 +230,17 @@ def from_pb(
             decoder_input_ids=decoder_input_ids,
             decoder_inputs_embeds=decoder_inputs_embeds,
             decoder_attention_mask=decoder_attention_mask,
-            all_decoder_input_ids=list(decoder_input_ids),
+            all_decoder_input_ids_tensor=all_decoder_input_ids_tensor,
             encoder_last_hidden_state=None,
             past_key_values=None,
             input_lengths=input_lengths,
             decoder_input_lengths=decoder_input_lengths,
             max_remaining_tokens=max_remaining_tokens,
-            next_token_choosers=next_token_choosers,
+            next_token_chooser=next_token_chooser,
             max_input_length=max_input_length,
             max_decoder_input_length=max_decoder_input_length,
             padding_right_offset=padding_right_offset,
+            pad_token_id=tokenizer.pad_token_id,
         ), errors

     @classmethod
@@ -247,12 +265,15 @@ def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch":
         input_lengths = []
         decoder_input_lengths = []
         max_remaining_tokens = []
-        next_token_choosers = []
-        all_decoder_input_ids = []
+        next_token_chooser_parameters = []
+        ntc_current_tokens = []
+        ntc_samplings = []
+        ntc_return_logprobs = []

         # Batch tensors
         attention_mask = None
         decoder_input_ids = None
+        all_decoder_input_ids_tensor = None
         decoder_attention_mask = None
         encoder_last_hidden_state = None
         past_key_values = []
@@ -267,8 +288,11 @@ def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch":
             input_lengths.extend(batch.input_lengths)
             decoder_input_lengths.extend(batch.decoder_input_lengths)
             max_remaining_tokens.extend(batch.max_remaining_tokens)
-            next_token_choosers.extend(batch.next_token_choosers)
-            all_decoder_input_ids.extend(batch.all_decoder_input_ids)
+
+            next_token_chooser_parameters.extend(r.parameters for r in batch.requests)
+            ntc_current_tokens.extend(batch.next_token_chooser.current_tokens)
+            ntc_samplings.extend(batch.next_token_chooser.samplings)
+            ntc_return_logprobs.extend(batch.next_token_chooser.return_logprobs)

             # Slicing end index for this batch
             end_index = start_index + len(batch)
@@ -293,6 +317,20 @@ def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch":
             # Copy to correct indices
             decoder_input_ids[start_index:end_index] = batch.decoder_input_ids

+            # Create padded tensor
+            if all_decoder_input_ids_tensor is None:
+                all_decoder_input_ids_tensor = batches[0].all_decoder_input_ids_tensor.new_full(
+                    (total_batch_size, max_decoder_input_length + padding_right_offset),
+                    batches[0].pad_token_id,
+                )
+            # Copy to correct sub-tensor
+            rhs_pad_diff = padding_right_offset - batch.padding_right_offset
+            all_decoder_input_ids_tensor[
+                start_index:end_index,
+                -(batch.all_decoder_input_ids_tensor.shape[1] + rhs_pad_diff)
+                :(all_decoder_input_ids_tensor.shape[1] - rhs_pad_diff)
+            ] = batch.all_decoder_input_ids_tensor
+
             # Create padded tensor
             if decoder_attention_mask is None:
                 # As decoder_attention_mask might not exist, we use `batch.attention_mask` for device here
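The slice above right-aligns each source batch's buffer inside the merged one: rhs_pad_diff is how many extra right-padding columns the merged batch reserves, so after the copy the last generated token still sits at column -1 - padding_right_offset. A small worked example with made-up sizes (plain torch, illustrative only):

import torch

max_decoder_input_length, padding_right_offset = 4, 3     # merged maxima
merged = torch.full((1, max_decoder_input_length + padding_right_offset), 0)   # width 7

src = torch.tensor([[5, 6, 0]])        # one source batch: 2 token columns + 1 padding column
batch_padding_right_offset = 1

rhs_pad_diff = padding_right_offset - batch_padding_right_offset   # 2
merged[:, -(src.shape[1] + rhs_pad_diff):(merged.shape[1] - rhs_pad_diff)] = src
print(merged)  # tensor([[0, 0, 5, 6, 0, 0, 0]])
# token 6 lands at column -1 - padding_right_offset, right-aligned against the merged
# batch's reserved output region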
@@ -391,6 +429,16 @@ def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch":

             start_index = end_index

+        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
+            pb=next_token_chooser_parameters,
+            model_eos_token_id=batches[0].next_token_chooser.eos_token_id,
+            return_logprobs=ntc_return_logprobs,
+            dtype=batches[0].next_token_chooser.dtype,
+            device=batches[0].next_token_chooser.device,
+            samplings=ntc_samplings,
+            current_tokens=ntc_current_tokens,
+        )
+
         return cls(
             batch_id=batches[0].batch_id,
             requests=requests,
@@ -401,15 +449,16 @@ def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch":
             decoder_inputs_embeds=None,
             decoder_attention_mask=decoder_attention_mask,
             encoder_last_hidden_state=encoder_last_hidden_state,
-            all_decoder_input_ids=all_decoder_input_ids,
+            all_decoder_input_ids_tensor=all_decoder_input_ids_tensor,
             past_key_values=past_key_values,
             input_lengths=input_lengths,
             decoder_input_lengths=decoder_input_lengths,
             max_remaining_tokens=max_remaining_tokens,
-            next_token_choosers=next_token_choosers,
+            next_token_chooser=next_token_chooser,
             max_input_length=max_input_length,
             max_decoder_input_length=max_decoder_input_length,
             padding_right_offset=padding_right_offset,
+            pad_token_id=batches[0].pad_token_id,
             keys_head_dim_last=batches[0].keys_head_dim_last,
         )

@@ -435,8 +484,7 @@ def prune(cls, batch: "Seq2SeqLMBatch", completed_ids: List[int]) -> Optional["S
         batch.decoder_input_lengths = list(slice_list(batch.decoder_input_lengths))
         batch.max_remaining_tokens = list(slice_list(batch.max_remaining_tokens))
         batch.requests = slice_list(batch.requests)
-        batch.next_token_choosers = slice_list(batch.next_token_choosers)
-        batch.all_decoder_input_ids = list(slice_list(batch.all_decoder_input_ids))
+        batch.next_token_chooser = batch.next_token_chooser.filter(keep_indices)

         batch.max_input_length = max(batch.input_lengths)
         batch.max_decoder_input_length = max(batch.decoder_input_lengths)
@@ -456,6 +504,12 @@ def prune(cls, batch: "Seq2SeqLMBatch", completed_ids: List[int]) -> Optional["S
         batch.attention_mask = batch.attention_mask[keep_indices, -batch.max_input_length:]
         batch.decoder_input_ids = batch.decoder_input_ids[keep_indices]

+        batch.all_decoder_input_ids_tensor = batch.all_decoder_input_ids_tensor[
+            keep_indices,
+            -(batch.padding_right_offset + batch.max_decoder_input_length)
+            :(batch.all_decoder_input_ids_tensor.shape[1] - batch.padding_right_offset) + new_padding_right_offset,
+        ]
+
         # Ensure that past_key_values tensors can be updated in-place
         if type(batch.past_key_values[0]) == tuple:
             batch.past_key_values = [list(layer) for layer in batch.past_key_values]
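The prune slice drops the rows of completed requests and shrinks the column window: it keeps the last max_decoder_input_length token columns (recomputed over the surviving requests) plus new_padding_right_offset columns of remaining right padding, measured against the old padding_right_offset. A toy example with assumed sizes:

import torch

t = torch.arange(24).reshape(3, 8)        # 3 requests, width 5 + 3 (old max length + old right padding)
keep_indices = torch.tensor([0, 2])       # request 1 finished
padding_right_offset = 3                  # old value, not yet updated
max_decoder_input_length = 4              # new max over the surviving requests
new_padding_right_offset = 2              # new max of remaining tokens

pruned = t[
    keep_indices,
    -(padding_right_offset + max_decoder_input_length)
    :(t.shape[1] - padding_right_offset) + new_padding_right_offset,
]
print(pruned.shape)  # torch.Size([2, 6]) -> 4 kept token columns + 2 right-padding columns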
@@ -589,8 +643,9 @@ def generate_token(
             batch.decoder_inputs_embeds,
         )

-        # New values for next forward
-        next_batch_decoder_input_ids = []
+        next_input_ids, next_token_scores, next_token_logprobs = batch.next_token_chooser(
+            input_ids=batch.all_decoder_input_ids_tensor, scores=logits[:, -1, :]
+        )

         # Generated tokens
         generated_tokens: List[TokenInfo] = []
@@ -601,8 +656,9 @@ def generate_token(
         iterator = zip(
             batch.requests,
             logits,
-            batch.next_token_choosers,
-            batch.all_decoder_input_ids,
+            next_input_ids,
+            next_token_scores,
+            next_token_logprobs,
             batch.decoder_input_lengths,
             batch.input_ids if first else NONES,
             batch.input_lengths if first else NONES,
@@ -612,20 +668,20 @@ def generate_token(
         for i, (
             request,
             logits,
-            next_token_chooser,
-            all_decoder_input_ids,
+            next_token,
+            scores,
+            logprobs,
             decoder_input_length,
             input_ids,
             input_length,
         ) in enumerate(iterator):
             try:
-                # Select next token
-                next_token, scores, logprobs = next_token_chooser(
-                    all_decoder_input_ids[None, -decoder_input_length:], logits
-                )
+                # Ensure tok view is 1st order, everything else is second.
+                tok_view = next_token.view(-1)
+                scores_view = scores.view(-1, scores.shape[-1])
+                logprobs_view = logprobs.view(-1, logprobs.shape[-1]) if request.details.logprobs else None

-                # Return latest token
-                token_info = get_token_info(request, scores, next_token, logprobs)
+                token_info = get_token_info(request, scores_view, tok_view, logprobs_view)

                 # Return input tokens if requested
                 # Note this could all be handled in the router for seq2seq models
@@ -645,25 +701,26 @@ def generate_token(

             except Exception as e:
                 logging.exception(f"token decoding error for request #{request.id}")
-                next_token = all_decoder_input_ids.new_tensor([self.tokenizer.pad_token_id])
+                next_token = batch.all_decoder_input_ids_tensor.new_tensor([self.tokenizer.pad_token_id])
                 # Add to the errors to return
                 decode_errors.append(GenerateError(
                     request_id=request.id, message=f"Token decoding error: {str(e)}"
                 ))

-            # Append next token to decoder tokens
-            next_batch_decoder_input_ids.append(next_token)
-            batch.all_decoder_input_ids[i] = torch.cat([all_decoder_input_ids, next_token])
+            # Adjust input/output counters
             batch.decoder_input_lengths[i] += 1
             batch.max_remaining_tokens[i] -= 1

+        # Add generated tokens to the input_ids_tensor
+        batch.all_decoder_input_ids_tensor[:, -batch.padding_right_offset] = next_input_ids
+
         # Update decoder_attention_mask as we added a new token to input_ids
         if batch.decoder_attention_mask is not None:
             batch.decoder_attention_mask[:, -batch.padding_right_offset] = 1

         batch.input_ids = None
         batch.inputs_embeds = None
-        batch.decoder_input_ids = torch.cat(next_batch_decoder_input_ids).view(-1, 1)
+        batch.decoder_input_ids = next_input_ids.view(-1, 1)
         batch.decoder_inputs_embeds = None
         batch.encoder_last_hidden_state = encoder_last_hidden_state
         batch.past_key_values = past
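At the end of a step the whole batch's new tokens are written into a single column of the pre-allocated buffer, and the same vector (reshaped to [batch_size, 1]) becomes decoder_input_ids for the next forward pass. A toy sketch with hypothetical values; padding_right_offset is assumed to shrink between steps elsewhere in the file (not shown in this diff):

import torch

buf = torch.tensor([[2, 4, 0, 0, 0],
                    [2, 7, 0, 0, 0]])          # bos + 1 generated token per request
padding_right_offset = 3                        # columns still reserved for future tokens

next_input_ids = torch.tensor([11, 12])         # output of the batched chooser
buf[:, -padding_right_offset] = next_input_ids  # fill the next free column for every row
decoder_input_ids = next_input_ids.view(-1, 1)  # shape [2, 1], fed to the next forward pass
print(buf)  # tensor([[ 2,  4, 11,  0,  0],
            #         [ 2,  7, 12,  0,  0]])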
