 # This code is adapted from huggingface transformers: https://github.com/huggingface/transformers/blob/v4.36.2/src/transformers/generation/logits_process.py
-from typing import List
+import logging
+from typing import List, Union
 
 import torch
 import torch.nn.functional as F
 
-_LOGIT_PROCESSOR_MAP = {}
+_LOGITS_PROCESSOR_MAP = {}
 
 
-def register_logit_processor(process_type):
+def register_logits_processor(process_type):
     """
     register a logits processor function for the given process type.
     """
 
     def register(func):
-        global _LOGIT_PROCESSOR_MAP
-        _LOGIT_PROCESSOR_MAP[process_type] = func
+        global _LOGITS_PROCESSOR_MAP
+        _LOGITS_PROCESSOR_MAP[process_type] = func
         return func
 
     return register
 
 
-@register_logit_processor("no_repeat_ngram_size")
-def no_repeat_ngram_size_logit_process(logits, ngram_size: int, batch_token_ids: List[List[int]]):
+@register_logits_processor("no_repeat_ngram_size")
+def apply_no_repeat_ngram_size(logits, ngram_size: int, batch_token_ids: List[List[int]]):
     """
     enforces no repetition of n-grams to avoid repetitions of word sequences.
     """
@@ -52,16 +53,16 @@ def no_repeat_ngram_size_logit_process(logits, ngram_size: int, batch_token_ids:
     return logits
 
 
-@register_logit_processor("repetition_penalty")
-def repetition_penalty_logit_process(logits, penalty: float, batch_token_ids: List[List[int]]):
+@register_logits_processor("repetition_penalty")
+def apply_repetition_penalty(logits, penalty: float, batch_token_ids: List[List[int]]):
     """
     apply the penalty to the tokens present in the prompt.
     """
 
     if not isinstance(penalty, float) or not (penalty > 0):
         raise ValueError(f"'penalty={penalty}' has to be a strictly positive float and greater than 0.")
 
-    logit_list = []
+    logits_list = []
 
     # TODO(yuehuayingxueluo) This is only a temporary implementation. Later, we will implement presence_penalties, frequency_penalties, and repetition_penalties using CUDA kernels.
     if penalty != 1.0:
@@ -71,15 +72,15 @@ def repetition_penalty_logit_process(logits, penalty: float, batch_token_ids: Li
 
             curretn_socre = torch.gather(current_logit, 0, current_token)
             curretn_socre = torch.where(curretn_socre < 0, curretn_socre * penalty, curretn_socre / penalty)
-            logit_list.append(current_logit.scatter(0, current_token, curretn_socre))
+            logits_list.append(current_logit.scatter(0, current_token, curretn_socre))
 
-        logits = torch.stack(logit_list)
+        logits = torch.stack(logits_list)
 
     return logits
 
 
-@register_logit_processor("temperature")
-def temperature_logit_process(logits, temperature: float):
+@register_logits_processor("temperature")
+def apply_temperature(logits, temperature: float):
     """
     apply temperature scaling.
     """
@@ -93,8 +94,8 @@ def temperature_logit_process(logits, temperature: float):
     return logits if temperature == 1.0 else logits / temperature
 
 
-@register_logit_processor("top_k")
-def top_k_logit_processor(logits, top_k: int):
+@register_logits_processor("top_k")
+def apply_top_k(logits, top_k: int):
     """
     top_k logit processor
     """
@@ -107,8 +108,8 @@ def top_k_logit_processor(logits, top_k: int):
     return logits
 
 
-@register_logit_processor("top_p")
-def top_p_logit_processor(logits, top_p: float):
+@register_logits_processor("top_p")
+def apply_top_p(logits, top_p: float):
     """
     top_p logit processor
     """
@@ -129,7 +130,46 @@ def top_p_logit_processor(logits, top_p: float):
     return logits
 
 
-def logit_processor(processor: str, logits, *args, **kwargs):
+@register_logits_processor("forced_eos_token_id")
+def apply_forced_eos_token_id(
+    logits: torch.Tensor,
+    sequence_lengths: Union[torch.Tensor, List[int]],
+    max_lengths: Union[torch.Tensor, List[int]],
+    eos_token_id: Union[int, List[int]],
+):
+    """
+    Enforces the specified token as the last generated token when the maximum output length
+    is reached. Notice that the maximum output lengths for different sequences, even if they're
+    in the same batch, can be different.
+
+    Args:
+        logits(torch.Tensor): logits
+        sequence_lengths(torch.Tensor): sequence lengths including prompt and output tokens
+        max_lengths(torch.Tensor): the maximum length for each sequence
+        eos_token_id(Union[int, List[int]]): forced eos token id
+    """
+    if isinstance(eos_token_id, int):
+        eos_token_id = [eos_token_id]
+    if isinstance(sequence_lengths, torch.Tensor):
+        sequence_lengths = sequence_lengths.tolist()
+    if isinstance(max_lengths, torch.Tensor):
+        max_lengths = max_lengths.tolist()
+
+    select_indexes = []
+    num_sequences = logits.shape[0]
+    sequence_lengths = sequence_lengths[:num_sequences]
+    max_lengths = max_lengths[:num_sequences]
+    for i, (sequence_length, max_out_length) in enumerate(zip(sequence_lengths, max_lengths)):
+        if sequence_length == max_out_length - 1:
+            select_indexes.append(i)
+    if select_indexes:
+        logits[select_indexes, :] = -float("inf")
+        logits[select_indexes, eos_token_id] = 0
+
+    return logits
+
+
+def get_logits_processor(processor: str, logits, *args, **kwargs):
     """
     do logit process for given logits.
 
@@ -140,9 +180,10 @@ def logit_processor(processor: str, logits, *args, **kwargs):
     Returns:
         logits after process
     """
-    if processor not in _LOGIT_PROCESSOR_MAP:
-        return logits
+    if processor not in _LOGITS_PROCESSOR_MAP:
+        logging.warning(f"Unsupported processor {processor}. Fall back to the original logits.")
     else:
-        func = _LOGIT_PROCESSOR_MAP[processor]
+        func = _LOGITS_PROCESSOR_MAP[processor]
         logits = func(logits, *args, **kwargs)
-    return logits
+
+    return logits
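
For reference, a minimal usage sketch of the registry after this change. It assumes the file above is saved locally as logit_processors.py (the real package path inside the repository may differ), and the batch size, vocabulary size, token ids, sampling parameters, and lengths below are illustrative values, not part of the diff:

import torch

from logit_processors import get_logits_processor  # assumed local module name

batch_size, vocab_size = 2, 32000  # illustrative sizes
logits = torch.randn(batch_size, vocab_size)
batch_token_ids = [[1, 5, 7], [2, 2, 9]]  # made-up token ids per sequence

# Processors are looked up by the name they were registered under;
# an unknown name now logs a warning and returns the logits unchanged.
logits = get_logits_processor("repetition_penalty", logits, 1.2, batch_token_ids)
logits = get_logits_processor("temperature", logits, 0.7)
logits = get_logits_processor("top_k", logits, 50)
logits = get_logits_processor("top_p", logits, 0.9)

# New processor added by this diff: once a sequence is one token away from its
# per-sequence maximum length, every token except eos_token_id is masked out.
sequence_lengths = [15, 31]  # prompt + generated tokens so far, per sequence
max_lengths = [64, 32]       # per-sequence maximum total lengths
eos_token_id = 2
logits = get_logits_processor("forced_eos_token_id", logits, sequence_lengths, max_lengths, eos_token_id)
# Here only the second sequence (31 == 32 - 1) is forced to emit eos_token_id next.

The decorator-based registry keeps sampling-time logits transforms decoupled from the sampler: a new processor only needs the register_logits_processor decorator, and callers select it by name.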