Skip to content

Commit de4bf3d

Browse files
authored
[Inference]Adapt repetition_penalty and no_repeat_ngram_size (#5708)
* Adapt repetition_penalty and no_repeat_ngram_size * fix no_repeat_ngram_size_logit_process * remove batch_updated * fix annotation * modified codes based on the review feedback. * rm get_batch_token_ids
1 parent 50104ab commit de4bf3d

File tree

5 files changed

+94
-18
lines changed

5 files changed

+94
-18
lines changed

colossalai/inference/batch_bucket.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,13 @@ def use_spec_dec(self) -> bool:
102102
def num_tokens_to_verify(self) -> int:
103103
return self._num_tokens_to_verify
104104

105+
@property
106+
def batch_token_ids(self) -> List[List[int]]:
107+
out = []
108+
for seq in self.seqs_li:
109+
out.append(seq.input_token_id + seq.output_token_id)
110+
return out
111+
105112
def set_use_spec_dec(self, num_tokens_to_verify: int = 5) -> None:
106113
"""Set batch bucket to use speculatvie decoding.
107114
This will notify the adjust the lengths of inputs during modeling,
@@ -328,6 +335,7 @@ def pop_n_seqs(
328335
seqs.append(seq)
329336
if not self.is_compact:
330337
self._make_compact()
338+
331339
return seqs, block_tables
332340

333341
def pop_finished(
@@ -432,6 +440,7 @@ def merge(self, other: "BatchBucket") -> List[int]:
432440
block_tables = torch.stack(block_tables_li)
433441
self.add_seqs(seqs, alloc_block_tables=block_tables)
434442
unmerged_ids = other.seqs_ids
443+
435444
return unmerged_ids
436445

437446
########## The following methods are expected to be used in modeling ###########

colossalai/inference/config.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,9 @@ class InferenceConfig:
9999
early_stopping (Optional[bool]): Whether to stop the generation when all beam hypotheses have finished or not, defaults to False.
100100
top_k (Optional[int]): The number of highest probability vocabulary tokens to keep for top-k-filtering, defaults to None.
101101
top_p (Optional[float]): The cumulative probability threshold for retaining tokens with a total probability above it, defaults to None.
102-
min_p (Optional[float]): The minimum probability to keep for top-p filtering, defaults to None.
102+
temperature (Optional[float]): Randomness used to control randomization, defaults to 1.0.
103+
repetition_penalty (Optional[float]): The parameter that influences the model's treatment of new tokens in relation to their appearance in the prompt and the generated text. Values greater than 1 incentivize the model to introduce new tokens, whereas values less than 1 incentivize token repetition., defaults to 1.0.
104+
no_repeat_ngram_size (Optional[int]): If no_repeat_ngram_size > 0, the consecutive tokens of ngram size can only appear once in inference sentences.
103105
n_spec_tokens (int): The maximum number of speculating tokens, defaults to None.
104106
glimpse_large_kv (bool): Whether to use large KV in drafter model, defaults to False.
105107
block_size (int): The number of blocks in a logical block, defaults to 16.
@@ -136,7 +138,9 @@ class InferenceConfig:
136138
early_stopping: Optional[bool] = False
137139
top_k: Optional[int] = None
138140
top_p: Optional[float] = None
139-
min_p: Optional[float] = None
141+
temperature: Optional[float] = 1.0
142+
no_repeat_ngram_size: Optional[int] = 0
143+
repetition_penalty: Optional[float] = 1.0
140144

141145
# speculative decoding configs
142146
max_n_spec_tokens: int = 5
@@ -213,7 +217,7 @@ def to_generation_config(self, model_config) -> GenerationConfig:
213217
"do_sample": self.do_sample,
214218
"num_beams": self.beam_width,
215219
}
216-
for type in ["top_k", "top_p", "min_p"]:
220+
for type in ["repetition_penalty", "no_repeat_ngram_size", "temperature", "top_k", "top_p"]:
217221
if hasattr(self, type):
218222
meta_config[type] = getattr(self, type)
219223
for type in ["pad_token_id", "bos_token_id", "eos_token_id"]:

colossalai/inference/core/engine.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -424,7 +424,7 @@ def steps_spec_dec(self) -> List[Sequence]:
424424

425425
# 2. Prefill main model (Verifier) - fill past kv cache for main model
426426
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
427-
next_tokens = self.request_handler.search_tokens(self.generation_config, logits)
427+
next_tokens = self.request_handler.search_tokens(self.generation_config, logits, batch)
428428
# append new inputs to the batch, temporarily
429429
batch.append_batch_tokens(next_tokens)
430430
self.request_handler.allocate_batch_spec_dec(batch, 1)
@@ -472,7 +472,7 @@ def steps_spec_dec(self) -> List[Sequence]:
472472
input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
473473
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
474474

475-
next_tokens = self.request_handler.search_tokens(self.generation_config, logits)
475+
next_tokens = self.request_handler.search_tokens(self.generation_config, logits, batch)
476476

477477
# 5. Compare and process the results
478478
diff_indexes = torch.nonzero(~(next_tokens[:-1] == next_token_ids_spec))
@@ -738,7 +738,7 @@ def step(self) -> List[str]:
738738
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
739739
if self.inference_config.pad_input:
740740
logits = logits[:, -1, :]
741-
next_tokens = self.request_handler.search_tokens(self.generation_config, logits)
741+
next_tokens = self.request_handler.search_tokens(self.generation_config, logits, batch)
742742
self.request_handler.append_next_tokens(next_tokens)
743743
finished_sequences = self.request_handler.update()
744744

colossalai/inference/core/request_handler.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,9 @@
1111
from colossalai.inference.logit_processors import logit_processor
1212
from colossalai.inference.sampler import *
1313
from colossalai.inference.struct import RequestStatus, Sequence
14-
from colossalai.logging import get_dist_logger
1514

1615
__all__ = ["RunningList", "RequestHandler"]
1716

18-
logger = get_dist_logger(__name__)
19-
2017

2118
class RunningList:
2219
"""
@@ -331,15 +328,21 @@ def check_unfinished_seqs(self) -> bool:
331328
def total_requests_in_batch_bucket(self) -> int:
332329
return self.prefill_bb.current_batch_size + self.running_bb.current_batch_size
333330

334-
def search_tokens(self, generation_config: GenerationConfig, logits):
331+
def search_tokens(self, generation_config: GenerationConfig, logits, cur_batch: BatchBucket):
335332
"""
336333
Sample tokens for finished requests.
337334
"""
338335

336+
# NOTE: need to decide the granularity to process logits (sequence or batch)
337+
config_dict = generation_config.to_dict()
338+
# process repetition_penalty, no_repeat_ngram_size
339+
for type in ["repetition_penalty", "no_repeat_ngram_size"]:
340+
if type in config_dict and config_dict[type] is not None:
341+
logits = logit_processor(type, logits, config_dict[type], cur_batch)
342+
339343
# do logit processor
340344
if generation_config.do_sample:
341-
# NOTE: need to decide the granularity to process logits (sequence or batch)
342-
config_dict = generation_config.to_dict()
345+
# process temperature, top_k, top_p
343346
for type in ["temperature", "top_k", "top_p"]:
344347
if type in config_dict and config_dict[type] is not None:
345348
logits = logit_processor(type, logits, config_dict[type])

colossalai/inference/logit_processors.py

Lines changed: 66 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
1+
# This code is adapted from huggingface transformers: https://github.com/huggingface/transformers/blob/v4.36.2/src/transformers/generation/logits_process.py
2+
13
import torch
24
import torch.nn.functional as F
35

6+
from colossalai.inference.batch_bucket import BatchBucket
7+
48
_LOGIT_PROCESSOR_MAP = {}
59

610

@@ -17,6 +21,66 @@ def register(func):
1721
return register
1822

1923

24+
@register_logit_processor("no_repeat_ngram_size")
25+
def no_repeat_ngram_size_logit_process(logits, ngram_size: int, batch: BatchBucket):
26+
"""
27+
enforces no repetition of n-grams to avoid repetitions of word sequences.
28+
"""
29+
30+
if not isinstance(ngram_size, int) or ngram_size < 0:
31+
raise ValueError(f"'temperature={ngram_size}' should be a strictly positive integer.")
32+
33+
if ngram_size != 0:
34+
batch_token_ids = batch.batch_token_ids
35+
batch_size = len(batch_token_ids)
36+
37+
for batch_id in range(batch_size):
38+
current_token_ids = batch_token_ids[batch_id]
39+
current_len = len(current_token_ids)
40+
if current_len + 1 < ngram_size:
41+
continue
42+
43+
ngrams_dict = {}
44+
45+
for ngram in zip(*[current_token_ids[i:] for i in range(ngram_size)]):
46+
prev_ngram_tuple = tuple(ngram[:-1])
47+
ngrams_dict[prev_ngram_tuple] = ngrams_dict.get(prev_ngram_tuple, []) + [ngram[-1]]
48+
49+
prev_ngrams = tuple(current_token_ids[current_len + 1 - ngram_size : current_len])
50+
banned_token = ngrams_dict.get(prev_ngrams, [])
51+
52+
logits[batch_id, banned_token] = -float("inf")
53+
54+
return logits
55+
56+
57+
@register_logit_processor("repetition_penalty")
58+
def repetition_penalty_logit_process(logits, penalty: float, batch: BatchBucket):
59+
"""
60+
apply the penalty to the tokens present in the prompt.
61+
"""
62+
63+
if not isinstance(penalty, float) or not (penalty > 0):
64+
raise ValueError(f"'penalty={penalty}' has to be a strictly positive float and greater than 0.")
65+
66+
logit_list = []
67+
68+
# TODO(yuehuayingxueluo) This is only a temporary implementation. Later, we will implement presence_penalties, frequency_penalties, and repetition_penalties using CUDA kernels.
69+
if penalty != 1.0:
70+
batch_token_ids = batch.batch_token_ids
71+
for batch_id in range(len(batch_token_ids)):
72+
current_logit = logits[batch_id]
73+
current_token = torch.tensor(batch_token_ids[batch_id], dtype=torch.long, device=logits.device)
74+
75+
curretn_socre = torch.gather(current_logit, 0, current_token)
76+
curretn_socre = torch.where(curretn_socre < 0, curretn_socre * penalty, curretn_socre / penalty)
77+
logit_list.append(current_logit.scatter(0, current_token, curretn_socre))
78+
79+
logits = torch.stack(logit_list)
80+
81+
return logits
82+
83+
2084
@register_logit_processor("temperature")
2185
def temperature_logit_process(logits, temperature: float):
2286
"""
@@ -68,14 +132,13 @@ def top_p_logit_processor(logits, top_p: float):
68132
return logits
69133

70134

71-
def logit_processor(processor: str, logits, attrs):
135+
def logit_processor(processor: str, logits, *args, **kwargs):
72136
"""
73137
do logit process for given logits.
74138
75139
Args:
76140
processor(str): the type of logit processor
77141
logits(torch.Tensor): input logits
78-
attrs(dict): attrs of the logit processor
79142
80143
Returns:
81144
logits after process
@@ -84,8 +147,5 @@ def logit_processor(processor: str, logits, attrs):
84147
return logits
85148
else:
86149
func = _LOGIT_PROCESSOR_MAP[processor]
87-
try:
88-
logits = func(logits, attrs)
89-
except Exception:
90-
return logits
150+
logits = func(logits, *args, **kwargs)
91151
return logits

0 commit comments

Comments
 (0)