Commit 2470e7d

tjohnson31415 authored and njhill committed
fix: repetition penalty bug if EOS and PAD tokens have the same id

Since the decoding vectorization changes, the pad tokens are also passed in to the repetition penalty processor. In the case where the pad token id is equal to the EOS token id, the EOS token gets penalized simply because the padded slots share its id, even though it was never generated. This bug was found when testing with the `EleutherAI/gpt-neox-20b` model in TGIS. Having pad token id == eos token id does not seem to be very common, but it is also the fallback if the pad token cannot be found another way.

There is also a small optimization in this PR: a view over all_input_ids_tensor is passed into `next_token_chooser` so that the pre-allocated output slots holding the pad token are not processed.

Signed-off-by: Travis Johnson <[email protected]>
1 parent 85918c5 commit 2470e7d
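
For context on the failure mode described above, here is a minimal sketch (not the server code; the ids, shapes, and scores below are made up) of how the vectorized repetition penalty gathers a score for every id in all_input_ids_tensor, so pad slots drag the EOS column into the penalty whenever pad_token_id == eos_token_id:

import torch

# Hypothetical ids and shapes, for illustration only.
EOS = PAD = 0                                        # e.g. gpt-neox-20b falls back to pad == eos
all_input_ids = torch.tensor([[5, 6, 7, PAD, PAD]])  # trailing slots are padding, not generated tokens
scores = torch.full((1, 8), 1.0)                     # tiny vocabulary of 8 ids
scores[0, EOS] = 9.0                                 # the model strongly wants to emit EOS

# The vectorized penalty gathers a score for *every* id in the tensor, so the
# pad slots pull in the EOS column and it gets penalized like a repeated token.
gathered = torch.gather(scores, 1, all_input_ids)
print(gathered)  # tensor([[1., 1., 1., 9., 9.]]) -- EOS captured twice via padding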

File tree

6 files changed (+79, -5 lines)


integration_tests/test_cases_tinystarcoderpy.yaml

Lines changed: 48 additions & 0 deletions

@@ -47,6 +47,54 @@
         stopReason: MAX_TOKENS
         text: "\nclass Shape(object):\n '''Shape class'''\n\n def __init__(self, x, y, z):"
 
+# Regression test case for a bug that was found with the vectorization changes
+# If a model has eos_token_id == pad_token_id, we need to make sure that the
+# repetition penalty doesn't penalize the EOS token score just because the
+# all_input_ids_tensor has padding.
+# See also https://github.ibm.com/ai-foundation/fmaas-inference-server/pull/399
+#
+# First, see what the output would be with no padding in the request
+- name: Regression Test - don't penalize EOS because of PAD [1]
+  request:
+    params:
+      method: GREEDY
+      stopping:
+        maxNewTokens: 10
+      decoding:
+        repetition_penalty: 100
+    requests:
+      - &hello_request {"text": "def print_hello():\n\t"}
+  response:
+    responses:
+      - &hello_response
+        generatedTokenCount: 8
+        inputTokenCount: 6
+        stopReason: EOS_TOKEN
+        text: "\tprint(\"Hello World!\")\n"
+# we should get the same result with padding
+- name: Regression Test - don't penalize EOS because of PAD [2]
+  request:
+    params:
+      method: GREEDY
+      stopping:
+        maxNewTokens: 10
+      decoding:
+        repetition_penalty: 100
+    requests:
+      - *hello_request
+      # need two requests, since there is no padding with a one request batch...
+      # the second request needs to be longer than the first and generate more
+      # than one token as well so that the first is processed with padding
+      - {"text": "# write a function that prints hello world"}
+  response:
+    responses:
+      - *hello_response
+      - generatedTokenCount: 10
+        inputTokenCount: 8
+        stopReason: MAX_TOKENS
+        text: "\ndef print_hello():\n # create an"
+
+
 # Multiple inputs with token info
 - name: Multiple inputs with token info
   request:
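
A rough numeric intuition (logit values invented, not taken from the model) for why the test pins repetition_penalty to 100: if the padded first request has its EOS logit penalized, greedy decoding keeps picking other tokens and the stop reason flips from EOS_TOKEN to MAX_TOKENS.

# Illustrative numbers only.
eos_logit, next_best_logit = 9.0, 4.0
penalty = 100.0

# CTRL-style penalty: positive scores are divided by the penalty,
# negative scores are multiplied by it.
penalized_eos = eos_logit / penalty if eos_logit > 0 else eos_logit * penalty

assert penalized_eos < next_best_logit  # greedy decoding would now skip EOS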

server/text_generation_server/models/causal_lm.py

Lines changed: 3 additions & 1 deletion

@@ -116,6 +116,7 @@ def from_pb(
         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
             pb=next_token_chooser_parameters,
             model_eos_token_id=getattr(tokenizer, 'model_eos_token_id', tokenizer.eos_token_id),
+            model_pad_token_id=tokenizer.pad_token_id,
             return_logprobs=return_logprobs,
             dtype=dtype,
             device=device,
@@ -322,6 +323,7 @@ def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
             pb=next_token_chooser_parameters,
             model_eos_token_id=batches[0].next_token_chooser.eos_token_id,
+            model_pad_token_id=batches[0].next_token_chooser.pad_token_id,
             return_logprobs=ntc_return_logprobs,
             dtype=batches[0].next_token_chooser.dtype,
             device=batches[0].next_token_chooser.device,
@@ -638,7 +640,7 @@ def generate_token(
 
         # Heterogeneous next token chooser expects last logits in the sequence
         next_input_ids, next_token_scores, next_token_logprobs = batch.next_token_chooser(
-            input_ids=batch.all_input_ids_tensor, scores=logits[:, -1, :]
+            input_ids=batch.all_input_ids_tensor[:, : -batch.padding_right_offset], scores=logits[:, -1, :]
         )
 
         # Generated tokens
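
A rough sketch of what the new view buys (shapes are made up, and padding_right_offset here is a plain local variable standing in for the batch attribute): the pre-allocated output slots on the right are dropped before the ids reach the logits processors, without copying the tensor. The flash-attention path below applies the same idea by slicing to max_seqlen, and the seq2seq path does the same on the decoder ids.

import torch

PAD = 0
# Pre-allocated ids for a batch of two requests; the trailing columns are
# output slots that have not been filled yet, so they still hold the pad id.
all_input_ids_tensor = torch.tensor([
    [5, 6, 7, 8, PAD, PAD, PAD],
    [9, 3, 4, 2, PAD, PAD, PAD],
])
padding_right_offset = 3  # slots reserved for tokens not yet generated

# The view handed to next_token_chooser now stops before the reserved slots.
view = all_input_ids_tensor[:, : -padding_right_offset]
assert view.shape == (2, 4)
# It is a view, not a copy: both tensors share the same storage.
assert view.data_ptr() == all_input_ids_tensor.data_ptr()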

server/text_generation_server/models/flash_causal_lm.py

Lines changed: 3 additions & 1 deletion

@@ -171,6 +171,7 @@ def from_pb(
         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
             pb=next_token_chooser_parameters,
             model_eos_token_id=getattr(tokenizer, 'model_eos_token_id', tokenizer.eos_token_id),
+            model_pad_token_id=tokenizer.pad_token_id,
             return_logprobs=return_logprobs,
             dtype=dtype,
             device=device,
@@ -258,6 +259,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch
         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
             pb=next_token_chooser_parameters,
             model_eos_token_id=first_next_token_chooser.eos_token_id,
+            model_pad_token_id=first_next_token_chooser.pad_token_id,
             return_logprobs=ntc_return_logprobs,
             dtype=first_next_token_chooser.dtype,
             device=first_next_token_chooser.device,
@@ -521,7 +523,7 @@ def _process_new_tokens(
         logits = out
 
         next_token_ids, next_token_scores, next_token_logprobs = batch.next_token_chooser(
-            input_ids=batch.all_input_ids_tensor, scores=logits,
+            input_ids=batch.all_input_ids_tensor[:, :batch.max_seqlen], scores=logits,
         )
 
         # add the next token ids to all_input_ids_tensor

server/text_generation_server/models/seq2seq_lm.py

Lines changed: 3 additions & 1 deletion

@@ -216,6 +216,7 @@ def from_pb(
         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
             pb=next_token_chooser_parameters,
             model_eos_token_id=getattr(tokenizer, 'model_eos_token_id', tokenizer.eos_token_id),
+            model_pad_token_id=tokenizer.pad_token_id,
             return_logprobs=return_logprobs,
             dtype=dtype,
             device=device
@@ -432,6 +433,7 @@ def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch":
         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
             pb=next_token_chooser_parameters,
             model_eos_token_id=batches[0].next_token_chooser.eos_token_id,
+            model_pad_token_id=batches[0].next_token_chooser.pad_token_id,
             return_logprobs=ntc_return_logprobs,
             dtype=batches[0].next_token_chooser.dtype,
             device=batches[0].next_token_chooser.device,
@@ -644,7 +646,7 @@ def generate_token(
         )
 
         next_input_ids, next_token_scores, next_token_logprobs = batch.next_token_chooser(
-            input_ids=batch.all_decoder_input_ids_tensor, scores=logits[:, -1, :]
+            input_ids=batch.all_decoder_input_ids_tensor[:, : - batch.padding_right_offset], scores=logits[:, -1, :]
         )
 
         # Generated tokens

server/text_generation_server/utils/logits_process.py

Lines changed: 15 additions & 1 deletion

@@ -102,13 +102,22 @@ class HeterogeneousRepetitionPenaltyLogitsProcessor(LogitsProcessor):
     paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
     """
 
-    def __init__(self, penalty: List[float], dtype: torch.dtype, device: torch.device):
+    def __init__(self, penalty: List[float], dtype: torch.dtype, device: torch.device, id_to_exclude: Optional[int] = None):
         self.penalty = penalty
         self.penalty_tensor = torch.tensor(
             penalty, dtype=dtype, device=device
         ).unsqueeze(1)
+        self.id_to_exclude = id_to_exclude
 
     def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
+        # as an optimization for a common case, we skip the exclusion if there is only
+        # one request in the batch (assumes that there is no padding with a single request)
+        do_exclude = self.id_to_exclude is not None and input_ids.shape[0] != 1
+        # save out the original scores if we are excluding an id
+        if do_exclude:
+            # tensor is updated in-place, so need to clone here
+            scores_of_id_to_exclude = scores[:, self.id_to_exclude].clone()
+
         score = torch.gather(scores, 1, input_ids)
 
         # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
@@ -117,6 +126,11 @@ def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tenso
         )
 
         scores.scatter_(1, input_ids, score)
+
+        # restore the scores for the "excluded" id
+        if do_exclude:
+            scores[:, self.id_to_exclude] = scores_of_id_to_exclude
+
         return scores
 
     def filter(self, indices):
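
To see the fix end to end, a small standalone reproduction (the processor's logic condensed into a plain function; batch contents are invented) of the save/restore behaviour added above: the excluded column is cloned before the in-place gather/scatter and written back afterwards, so padding with pad_token_id == eos_token_id no longer depresses the EOS score.

import torch

def penalize(scores, input_ids, penalty, id_to_exclude=None):
    # Condensed from HeterogeneousRepetitionPenaltyLogitsProcessor.__call__;
    # single-request batches skip the exclusion, as in the real code.
    do_exclude = id_to_exclude is not None and input_ids.shape[0] != 1
    if do_exclude:
        saved = scores[:, id_to_exclude].clone()  # scores is updated in-place below
    score = torch.gather(scores, 1, input_ids)
    score = torch.where(score < 0, score * penalty, score / penalty)
    scores.scatter_(1, input_ids, score)
    if do_exclude:
        scores[:, id_to_exclude] = saved          # restore the EOS/PAD column
    return scores

EOS = PAD = 0
penalty = torch.tensor([[100.0], [100.0]])
input_ids = torch.tensor([[5, 6, PAD, PAD], [5, 6, 7, 8]])   # request 0 is padded
scores = torch.tensor([[9.0, 1.0, 2.0, 3.0, 2.5, 1.5, 0.5, 0.1, 0.2],
                       [9.0, 1.0, 2.0, 3.0, 2.5, 1.5, 0.5, 0.1, 0.2]])

out = penalize(scores, input_ids, penalty, id_to_exclude=EOS)
print(out[0, EOS])  # tensor(9.) -- EOS untouched despite the pad columns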

server/text_generation_server/utils/tokens.py

Lines changed: 7 additions & 1 deletion

@@ -173,6 +173,7 @@ def __init__(
         min_new_tokens: List[int],
         return_logprobs: List[bool],
         eos_token_id: Optional[int] = None,
+        pad_token_id: Optional[int] = None,
         device: Optional[torch.device] = None,
         dtype: torch.dtype = None,
         # allow passing in existing values to support combining HNTC instances
@@ -181,7 +182,9 @@
         warpers = []
         self.repetition_processor = (
             HeterogeneousRepetitionPenaltyLogitsProcessor(
-                repetition_penalty, dtype, device
+                repetition_penalty, dtype, device,
+                # do not penalize the eos token if it is the same id as the pad token
+                id_to_exclude = eos_token_id if eos_token_id == pad_token_id else None,
             )
             if any(x != 1.0 for x in repetition_penalty)
             else None
@@ -215,6 +218,7 @@
 
         self.warpers = warpers
         self.eos_token_id = eos_token_id
+        self.pad_token_id = pad_token_id
         self.length_penalty = length_penalty
         self.min_new_tokens = min_new_tokens
         self.current_tokens = current_tokens if current_tokens is not None else [0] * len(do_sample)
@@ -267,6 +271,7 @@ def from_pb(
         cls,
         pb: List[generate_pb2.NextTokenChooserParameters],
         model_eos_token_id: Optional[int],
+        model_pad_token_id: Optional[int],
         return_logprobs: List[bool],
         dtype: torch.dtype,
         device: torch.device,
@@ -291,6 +296,7 @@
             seeds=seeds,
             min_new_tokens=[pb_.min_new_tokens for pb_ in pb],
             eos_token_id=model_eos_token_id,
+            pad_token_id=model_pad_token_id,
             return_logprobs=return_logprobs,
             device=device,
             dtype=dtype,
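
Restated as a tiny hedged sketch (a hypothetical free-standing helper, not the actual HeterogeneousNextTokenChooser code): the pad id only changes behaviour when it collides with the EOS id, and the repetition processor is only constructed when some request in the batch actually uses a penalty.

from typing import List, Optional

def pick_id_to_exclude(eos_token_id: Optional[int],
                       pad_token_id: Optional[int]) -> Optional[int]:
    # Do not penalize the EOS token if it shares an id with the PAD token;
    # with distinct ids the penalty behaves exactly as before.
    return eos_token_id if eos_token_id == pad_token_id else None

def needs_repetition_processor(repetition_penalty: List[float]) -> bool:
    # The processor is skipped entirely when no request in the batch uses it.
    return any(x != 1.0 for x in repetition_penalty)

assert pick_id_to_exclude(0, 0) == 0     # gpt-neox-20b style fallback: pad == eos
assert pick_id_to_exclude(2, 0) is None  # distinct ids: nothing is excluded
assert not needs_repetition_processor([1.0, 1.0])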
