Commit 283c407

[Inference] Fix Inference Generation Config and Sampling (#5710)
* refactor and add
* config default values
* fix gen config passing
* fix rpc generation config
Parent: 8bcfe36 · Commit: 283c407

File tree: 6 files changed (+125, -69 lines)

colossalai/inference/config.py

Lines changed: 3 additions & 2 deletions
@@ -202,11 +202,12 @@ class InferenceConfig(RPC_PARAM):
     ] = 1.2  # the ratio of prefill sequences to decoding sequences, we do prefill step once the actual value exceeds ratio
     pad_input: bool = False
     early_stopping: Optional[bool] = False
-    top_k: Optional[int] = None
-    top_p: Optional[float] = None
+    top_k: Optional[int] = 50
+    top_p: Optional[float] = 1.0
     temperature: Optional[float] = 1.0
     no_repeat_ngram_size: Optional[int] = 0
     repetition_penalty: Optional[float] = 1.0
+    forced_eos_token_id: int = None
 
     # speculative decoding configs
     max_n_spec_tokens: int = 5
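
For reference, the new defaults match Hugging Face's own GenerationConfig defaults, so converting the inference config (as to_generation_config presumably does) leaves behavior unchanged unless do_sample is switched on. A minimal sketch, assuming only the transformers package:

    from transformers import GenerationConfig

    # Round-trip of the new defaults; top_k=50 / top_p=1.0 / temperature=1.0 are also the
    # transformers defaults and stay inert while do_sample remains False.
    cfg = GenerationConfig(top_k=50, top_p=1.0, temperature=1.0, repetition_penalty=1.0, no_repeat_ngram_size=0)
    print(cfg.to_dict()["top_k"], cfg.to_dict()["do_sample"])  # 50 False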

colossalai/inference/core/engine.py

Lines changed: 13 additions & 9 deletions
@@ -76,6 +76,7 @@ def __init__(
         self.init_model(model_or_path, model_policy)
 
         self.generation_config = inference_config.to_generation_config(self.model_config)
+        self.generation_config_dict = self.generation_config.to_dict()
 
         self.tokenizer = tokenizer
         self.tokenizer.pad_token = self.tokenizer.eos_token
@@ -524,12 +525,13 @@ def generate(
         Returns:
             List[str]: Inference result returned by one generation.
         """
+
+        gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
+        prompts = [prompts] if isinstance(prompts, str) else prompts
+        request_ids = [request_ids] if isinstance(request_ids, int) else request_ids
+
         with torch.inference_mode():
-            if isinstance(prompts, str) and isinstance(request_ids, int):
-                prompts = [prompts]
-                request_ids = [request_ids]
             if prompts is not None or prompts_token_ids is not None:
-                gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
                 self.add_request(
                     request_ids=request_ids,
                     prompts=prompts,
@@ -543,6 +545,7 @@ def generate(
         # intuition: If user provide a generation config, we should replace the existing one.
         if generation_config is not None:
             self.generation_config = generation_config
+            self.generation_config_dict = gen_config_dict
 
         if self.use_spec_dec:
             assert self.drafter is not None, "Drafter Model is not initialized."
@@ -688,11 +691,12 @@ def prepare_input(self, batch: BatchBucket) -> Tuple[torch.Tensor, torch.Tensor,
         )
 
         batch_token_ids = None
-        config_dict = self.generation_config.to_dict()
-        # process repetition_penalty, no_repeat_ngram_size
-        for type in ["repetition_penalty", "no_repeat_ngram_size"]:
-            if type in config_dict and config_dict[type] is not None:
-                batch_token_ids = batch.batch_token_ids
+        if (
+            self.generation_config.repetition_penalty != 1.0
+            or self.generation_config.no_repeat_ngram_size > 0
+            or self.generation_config.forced_eos_token_id is not None
+        ):
+            batch_token_ids = batch.batch_token_ids
 
         # only when we have the graph for specific decoding batch size can we use the cuda graph for inference
         use_cuda_graph = False
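
As a standalone illustration (not the repo's code), the dict-membership loop above is replaced by an explicit check of the three settings that actually need per-sequence token history. A sketch using transformers' GenerationConfig attributes, with a hypothetical helper name:

    from transformers import GenerationConfig

    def needs_batch_token_ids(cfg: GenerationConfig) -> bool:
        # Token history is only required when a history-dependent logits processor is active.
        return (
            cfg.repetition_penalty != 1.0
            or cfg.no_repeat_ngram_size > 0
            or cfg.forced_eos_token_id is not None
        )

    print(needs_batch_token_ids(GenerationConfig()))                        # False
    print(needs_batch_token_ids(GenerationConfig(repetition_penalty=1.2)))  # True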

colossalai/inference/core/rpc_engine.py

Lines changed: 6 additions & 1 deletion
@@ -257,7 +257,12 @@ async def step_(self, input_token_ids, input_meta_data: InputMetaData):
         assert len(self.workers) == self.tp_size, "init workers first"
 
         init_tasks = [
-            self.async_parallel_wrapper(worker.execute_model_forward, input_token_ids, input_meta_data.to_rpc_param())
+            self.async_parallel_wrapper(
+                worker.execute_model_forward,
+                input_token_ids,
+                input_meta_data.to_rpc_param(),
+                self.generation_config_dict,
+            )
             for worker in self.workers
         ]
         ret = await asyncio.gather(*init_tasks)
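
The wrapper call above fans the same forward request, now including the plain-dict generation config, out to every tensor-parallel worker. A self-contained sketch of that pattern (FakeWorker, fan_out, and asyncio.to_thread are illustrative stand-ins for the real RPC wrapper):

    import asyncio

    class FakeWorker:
        def execute_model_forward(self, token_ids, meta_param, gen_config_dict):
            # Stand-in for the remote call; echoes the config it was handed.
            return {"n_tokens": len(token_ids), "top_k": gen_config_dict.get("top_k")}

    async def fan_out(workers, token_ids, meta_param, gen_config_dict):
        tasks = [
            asyncio.to_thread(w.execute_model_forward, token_ids, meta_param, gen_config_dict)
            for w in workers
        ]
        return await asyncio.gather(*tasks)

    print(asyncio.run(fan_out([FakeWorker(), FakeWorker()], [1, 2, 3], {}, {"top_k": 50})))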

colossalai/inference/executor/rpc_worker.py

Lines changed: 4 additions & 2 deletions
@@ -97,7 +97,9 @@ def exposed_init_cache(self, alloc_shape: Tuple[Tuple[int, ...], Tuple[int, ...]
         )
         logger.info("physical cache init over")
 
-    def exposed_execute_model_forward(self, input_token_ids_param: List[int], input_meta_data_param: dict):
+    def exposed_execute_model_forward(
+        self, input_token_ids_param: List[int], input_meta_data_param: dict, generation_config_param: dict
+    ):
         # prepare the data for model forward
         input_meta_data = InputMetaData.from_rpc_param(input_meta_data_param)
         input_meta_data.fd_inter_tensor = self.fd_inter_tensor
@@ -120,7 +122,7 @@ def exposed_execute_model_forward(self, input_token_ids_param: List[int], input_
         if self.inference_config.pad_input:
            logits = logits[:, -1, :]
         next_tokens = search_tokens(
-            self.inference_config.to_generation_config(self.model_config),
+            generation_config_param,
             logits,
             input_meta_data.is_prompts,
             input_meta_data.batch_token_ids,
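
One motivation for shipping the config as a dict instead of rebuilding a GenerationConfig on the worker: a plain dict crosses the RPC boundary as an ordinary picklable object, and search_tokens now reads it with .get(). A hedged sketch, assuming only the transformers package:

    from transformers import GenerationConfig

    # Client side: flatten the config once; the payload is a plain dict.
    payload = GenerationConfig(do_sample=True, top_k=50, top_p=0.9).to_dict()

    # Worker side: consume it without reconstructing any class.
    assert isinstance(payload, dict) and payload["top_p"] == 0.9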

colossalai/inference/logit_processors.py

Lines changed: 64 additions & 23 deletions
@@ -1,27 +1,28 @@
 # This code is adapted from huggingface transformers: https://github.com/huggingface/transformers/blob/v4.36.2/src/transformers/generation/logits_process.py
-from typing import List
+import logging
+from typing import List, Union
 
 import torch
 import torch.nn.functional as F
 
-_LOGIT_PROCESSOR_MAP = {}
+_LOGITS_PROCESSOR_MAP = {}
 
 
-def register_logit_processor(process_type):
+def register_logits_processor(process_type):
     """
     register flops computation function for operation.
     """
 
     def register(func):
-        global _LOGIT_PROCESSOR_MAP
-        _LOGIT_PROCESSOR_MAP[process_type] = func
+        global _LOGITS_PROCESSOR_MAP
+        _LOGITS_PROCESSOR_MAP[process_type] = func
         return func
 
     return register
 
 
-@register_logit_processor("no_repeat_ngram_size")
-def no_repeat_ngram_size_logit_process(logits, ngram_size: int, batch_token_ids: List[List[int]]):
+@register_logits_processor("no_repeat_ngram_size")
+def apply_no_repeat_ngram_size(logits, ngram_size: int, batch_token_ids: List[List[int]]):
     """
     enforces no repetition of n-grams to avoid repetitions of word sequences.
     """
@@ -52,16 +53,16 @@ def no_repeat_ngram_size_logit_process(logits, ngram_size: int, batch_token_ids:
     return logits
 
 
-@register_logit_processor("repetition_penalty")
-def repetition_penalty_logit_process(logits, penalty: float, batch_token_ids: List[List[int]]):
+@register_logits_processor("repetition_penalty")
+def apply_repetition_penalty(logits, penalty: float, batch_token_ids: List[List[int]]):
     """
     apply the penalty to the tokens present in the prompt.
     """
 
     if not isinstance(penalty, float) or not (penalty > 0):
         raise ValueError(f"'penalty={penalty}' has to be a strictly positive float and greater than 0.")
 
-    logit_list = []
+    logits_list = []
 
     # TODO(yuehuayingxueluo) This is only a temporary implementation. Later, we will implement presence_penalties, frequency_penalties, and repetition_penalties using CUDA kernels.
     if penalty != 1.0:
@@ -71,15 +72,15 @@ def repetition_penalty_logit_process(logits, penalty: float, batch_token_ids: Li
 
             curretn_socre = torch.gather(current_logit, 0, current_token)
             curretn_socre = torch.where(curretn_socre < 0, curretn_socre * penalty, curretn_socre / penalty)
-            logit_list.append(current_logit.scatter(0, current_token, curretn_socre))
+            logits_list.append(current_logit.scatter(0, current_token, curretn_socre))
 
-        logits = torch.stack(logit_list)
+        logits = torch.stack(logits_list)
 
     return logits
 
 
-@register_logit_processor("temperature")
-def temperature_logit_process(logits, temperature: float):
+@register_logits_processor("temperature")
+def apply_temperature(logits, temperature: float):
     """
     apply temperature scaling.
     """
@@ -93,8 +94,8 @@ def temperature_logit_process(logits, temperature: float):
     return logits if temperature == 1.0 else logits / temperature
 
 
-@register_logit_processor("top_k")
-def top_k_logit_processor(logits, top_k: int):
+@register_logits_processor("top_k")
+def apply_top_k(logits, top_k: int):
     """
     top_k logit processor
     """
@@ -107,8 +108,8 @@ def top_k_logit_processor(logits, top_k: int):
     return logits
 
 
-@register_logit_processor("top_p")
-def top_p_logit_processor(logits, top_p: float):
+@register_logits_processor("top_p")
+def apply_top_p(logits, top_p: float):
     """
     top_p logit processor
     """
@@ -129,7 +130,46 @@ def top_p_logit_processor(logits, top_p: float):
     return logits
 
 
-def logit_processor(processor: str, logits, *args, **kwargs):
+@register_logits_processor("forced_eos_token_id")
+def apply_forced_eos_token_id(
+    logits: torch.Tensor,
+    sequence_lengths: Union[torch.Tensor, List[int]],
+    max_lengths: Union[torch.Tensor, List[int]],
+    eos_token_id: Union[int, List[int]],
+):
+    """
+    Enforces the specified token as the last generated token when the maximum output length
+    is reached. Notice that the maximum output lengths for different sequences, even if they're
+    in the same batch, can be different.
+
+    Args:
+        logits(torch.Tensor): logits
+        sequence_lengths(torch.Tensor): sequence lengths including prompt and output tokens
+        max_lengths(torch.Tensor): the maximum length for each sequence
+        eos_token_id(Union[int, List[int]]): forced eos token id
+    """
+    if isinstance(eos_token_id, int):
+        eos_token_id = [eos_token_id]
+    if isinstance(sequence_lengths, torch.Tensor):
+        sequence_lengths = sequence_lengths.tolist()
+    if isinstance(max_lengths, torch.Tensor):
+        max_lengths = max_lengths.tolist()
+
+    select_indexes = []
+    num_sequences = logits.shape[0]
+    sequence_lengths = sequence_lengths[:num_sequences]
+    max_lengths = max_lengths[:num_sequences]
+    for i, (sequence_length, max_out_length) in enumerate(zip(sequence_lengths, max_lengths)):
+        if sequence_length == max_out_length - 1:
+            select_indexes.append(i)
+    if select_indexes:
+        logits[select_indexes, :] = -float("inf")
+        logits[select_indexes, eos_token_id] = 0
+
+    return logits
+
+
+def get_logits_processor(processor: str, logits, *args, **kwargs):
     """
     do logit process for given logits.
 
@@ -140,9 +180,10 @@ def logit_processor(processor: str, logits, *args, **kwargs):
     Returns:
         logits after process
     """
-    if processor not in _LOGIT_PROCESSOR_MAP:
-        return logits
+    if processor not in _LOGITS_PROCESSOR_MAP:
+        logging.warning(f"Unsupported processor {processor}. Fall back to the original logits.")
     else:
-        func = _LOGIT_PROCESSOR_MAP[processor]
+        func = _LOGITS_PROCESSOR_MAP[processor]
         logits = func(logits, *args, **kwargs)
-    return logits
+
+    return logits
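
A tiny numeric demo of the new forced-EOS processor's effect (standalone code mirroring the logic added above, with made-up lengths and a toy vocabulary): when a sequence sits one token away from its maximum length, every logit except the EOS id is pushed to -inf, so EOS must be chosen next.

    import torch

    logits = torch.zeros(2, 5)                 # batch of 2 sequences, toy vocab of 5
    sequence_lengths, max_lengths, eos_id = [3, 7], [4, 16], 4
    force = [i for i, (s, m) in enumerate(zip(sequence_lengths, max_lengths)) if s == m - 1]
    if force:
        logits[force, :] = -float("inf")
        logits[force, eos_id] = 0
    print(logits.argmax(dim=-1))               # tensor([4, 0]): only sequence 0 is forced to EOS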

colossalai/inference/sampler.py

Lines changed: 35 additions & 32 deletions
@@ -1,13 +1,12 @@
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Union
 
 import torch
 from transformers.generation import GenerationConfig
 
-from colossalai.inference.logit_processors import logit_processor
+from colossalai.inference.logit_processors import get_logits_processor
 
 
 def greedy_sample(
-    generation_config,
     logprobs: torch.Tensor,
 ) -> torch.Tensor:
     """
@@ -18,7 +17,6 @@ def greedy_sample(
 
 
 def multinomial_sample(
-    generation_config,
     probs: torch.Tensor,
 ) -> torch.Tensor:
     """
@@ -29,7 +27,7 @@ def multinomial_sample(
 
 
 def beam_search_sample(
-    generation_config,
+    beam_width: int,
     logprobs: torch.Tensor,
     is_prompt: bool = False,
 ) -> List[Tuple[List[int], List[int]]]:
@@ -46,7 +44,6 @@ def beam_search_sample(
     # NOTE: this beam search sample function is wrong now.
     """
 
-    beam_width = generation_config.num_beams
     results = []
     if is_prompt:
         # Prompt phase.
@@ -64,20 +61,8 @@ def beam_search_sample(
     return results
 
 
-def _sample(probs: torch.Tensor, logprobs: torch.Tensor, generation_config: GenerationConfig, is_prompt: bool = False):
-    if generation_config.num_beams == 1:
-        if generation_config.do_sample:
-            sample_tokens = multinomial_sample(generation_config, probs)
-        else:
-            sample_tokens = greedy_sample(generation_config, logprobs)
-    else:
-        sample_tokens = beam_search_sample(generation_config, logprobs, is_prompt=is_prompt)
-
-    return sample_tokens
-
-
 def search_tokens(
-    generation_config: GenerationConfig,
+    generation_config: Union[GenerationConfig, dict],
     logits,
     is_prompt: bool = False,
     batch_token_ids: Optional[List[List[int]]] = None,
@@ -86,23 +71,41 @@ def search_tokens(
     Sample tokens for finished requests.
     """
     # NOTE: need to decide the granularity to process logits (sequence or batch)
-    config_dict = generation_config.to_dict()
-    # process repetition_penalty, no_repeat_ngram_size
-    for type in ["repetition_penalty", "no_repeat_ngram_size"]:
-        if type in config_dict and config_dict[type] is not None:
-            logits = logit_processor(type, logits, config_dict[type], batch_token_ids)
-
-    # do logit processor
-    if generation_config.do_sample:
-        # process temperature, top_k, top_p
-        for type in ["temperature", "top_k", "top_p"]:
-            if type in config_dict and config_dict[type] is not None:
-                logits = logit_processor(type, logits, config_dict[type])
+
+    # convert GenerationConfig to dict
+    # temporary fix for compatibility with the usage of RPCInferenceEngine
+    if isinstance(generation_config, GenerationConfig):
+        generation_config = generation_config.to_dict()
+
+    if (repetition_penalty := generation_config.get("repetition_penalty", 1.0)) != 1.0:
+        logits = get_logits_processor("repetition_penalty", logits, repetition_penalty, batch_token_ids)
+    if (no_repeat_ngram_size := generation_config.get("no_repeat_ngram_size", 0)) > 0:
+        logits = get_logits_processor("no_repeat_ngram_size", logits, no_repeat_ngram_size, batch_token_ids)
+    if (forced_eos_token_id := generation_config.get("forced_eos_token_id", None)) is not None:
+        sequence_lengths = [len(batch_token_ids[i]) for i in range(len(batch_token_ids))]
+        max_out_lengths = [generation_config.max_length for _ in range(len(batch_token_ids))]
+        logits = get_logits_processor(
+            "forced_eos_token_id", logits, sequence_lengths, max_out_lengths, forced_eos_token_id
+        )
+
+    if generation_config.get("do_sample"):
+        if (temperature := generation_config.get("temperature", 1.0)) != 1.0:
+            logits = get_logits_processor("temperature", logits, temperature)
+        if (top_k := generation_config.get("top_k", 0)) != 0:
+            logits = get_logits_processor("top_k", logits, top_k)
+        if (top_p := generation_config.get("top_p", 1.0)) < 1.0:
+            logits = get_logits_processor("top_p", logits, top_p)
 
     # calculate probs
     probs = torch.softmax(logits, dim=-1, dtype=torch.float)
     logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
 
     # sample the next tokens
-    sample_tokens = _sample(probs, logprobs, generation_config, is_prompt)
+    if generation_config.get("num_beams", 1) != 1:
+        raise NotImplementedError("Beam search is not supported yet.")
+    if generation_config.get("do_sample", False):
+        sample_tokens = multinomial_sample(probs)
+    else:
+        sample_tokens = greedy_sample(logprobs)
+
     return sample_tokens
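
A hedged usage sketch of the reworked search_tokens (assuming a ColossalAI install that includes this commit): it accepts either a GenerationConfig or the plain dict forwarded by the RPC engine, applies only the processors whose settings are non-default, and falls back to greedy decoding when do_sample is off.

    import torch
    from colossalai.inference.sampler import search_tokens

    logits = torch.randn(2, 32000)                       # [batch_size, vocab_size]
    batch_token_ids = [[1, 2, 3], [4, 5, 6]]

    # Dict form, as the RPC worker now passes it.
    next_ids = search_tokens(
        {"do_sample": True, "temperature": 0.7, "top_k": 50, "top_p": 0.9},
        logits,
        batch_token_ids=batch_token_ids,
    )
    print(next_ids)                                      # one sampled next-token id per sequence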
