from __future__ import annotations
+import math
import copy
import torch
import inspect
@@ -56,9 +57,9 @@ def __call__(self, scores: torch.FloatTensor) -> torch.FloatTensor:
class TemperatureLogitsWarper:
    def __init__(self, temperature: float):
-
        if not (temperature > 0):
-            raise ValueError(f"`temperature` (={temperature}) must be positive temperature > 0")
+            raise ValueError(f"`temperature` (={temperature}) must be a positive number > 0")
+
        self.temperature = temperature

    def __call__(self, scores: torch.FloatTensor) -> torch.FloatTensor:
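Since the hunk cuts off before the body of __call__, a note on what temperature warping conventionally does: the logits are divided by the temperature, so values below 1 sharpen the distribution and values above 1 flatten it. A minimal sketch of that standard behavior, not necessarily this file's exact body:

import torch

scores = torch.tensor([[2.0, 1.0, 0.0]])
temperature = 0.5
print(scores / temperature)  # tensor([[4., 2., 0.]]) -- gaps widen, sampling gets greedier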
@@ -86,10 +87,30 @@ def __call__(self, scores: torch.FloatTensor) -> torch.FloatTensor:
        scores_processed = scores.masked_fill(indices_to_remove, self.filter_value)
        return scores_processed

+class MinLengthLogitsProcessor:
+    def __init__(self, min_length: int, eos_token_id: torch.Tensor):
+        self.min_length = min_length
+        self.eos_token_id = eos_token_id
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        if input_ids is None:
+            return scores
+
+        # suppress EOS until at least `min_length` tokens have been generated
+        vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
+        eos_token_mask = torch.isin(vocab_tensor, self.eos_token_id)
+        scores_processed = scores.clone()
+        if input_ids.shape[-1] < self.min_length:
+            scores_processed = torch.where(eos_token_mask, -math.inf, scores)
+        return scores_processed
+

def get_logits_processing(config: GenerationConfig):
    # TODO: add support for beam search with diversity penalty
    logits_processors = []

+    if config._eos_token_tensor is not None and config.min_length > 1:
+        logits_processors.append(MinLengthLogitsProcessor(config.min_length, config._eos_token_tensor))
+
    if config.top_k is not None and config.top_k != 0:
        logits_processors.append(TopKLogits(config.top_k))
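A quick sanity check of the new processor (MinLengthLogitsProcessor as defined above; the token ids, vocab size, and EOS id are illustrative):

import math
import torch

proc = MinLengthLogitsProcessor(min_length=5, eos_token_id=torch.tensor([2]))
input_ids = torch.tensor([[0, 7, 9]])  # only 3 tokens generated so far
scores = torch.zeros(1, 11)            # batch of 1, vocab of 11
out = proc(input_ids, scores)
assert out[0, 2] == -math.inf          # EOS suppressed until min_length is reached
assert torch.isfinite(out[0, 3])       # all other tokens keep their scores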
@@ -101,28 +122,59 @@ def get_logits_processing(config: GenerationConfig):
    return logits_processors

-def apply_logits_processing(logits, logits_processing_list, **kwargs):
+def apply_logits_processing(input_ids, logits, logits_processing_list, **kwargs):
    for process in logits_processing_list:
        func_args = inspect.signature(process.__call__).parameters
-        if not all(arg in kwargs for arg in list(func_args.keys())[1:]):
+        # skip the `input_ids` and `scores` positional parameters when checking required kwargs
+        if not all(arg in kwargs for arg in list(func_args.keys())[2:]):
            raise ValueError(
                f"Make sure that all the required parameters: {list(func_args.keys())} for "
                f"{process.__class__} are passed to the logits processor."
            )
-        logits = process(logits, **kwargs)
+        if "input_ids" in func_args:
+            logits = process(input_ids, logits)
+        else:
+            logits = process(logits, **kwargs)
    return logits
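To make the signature-based dispatch concrete, a hedged sketch using the classes from this file (tensor shapes and hyperparameters are illustrative): processors whose __call__ declares input_ids receive the generated ids, while score-only warpers keep the old calling convention.

import torch

processors = [
    MinLengthLogitsProcessor(min_length=4, eos_token_id=torch.tensor([2])),  # called as process(input_ids, logits)
    TemperatureLogitsWarper(temperature=0.7),                                # called as process(logits)
]
input_ids = torch.tensor([[0, 7, 9]])
logits = torch.randn(1, 11)
logits = apply_logits_processing(input_ids, logits, processors)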
-def check_stopping_criteria(input_ids: torch.Tensor, max_length: int, eos_token):
+def check_stopping_strings(input_ids: torch.Tensor, stop_strings: list) -> torch.BoolTensor:
+    # stop_strings must be a list of token-id sequences: List[List[int]]
+    device = input_ids.device
+    batch_size, seq_len = input_ids.shape
+    finished = torch.zeros(batch_size, dtype=torch.bool, device=device)
+
+    for b in range(batch_size):
+        row = input_ids[b]
+        # check each stop token sequence
+        for stop_ids in stop_strings:
+            n = len(stop_ids)
+            if n == 0 or n > seq_len:
+                continue
+            # compare the tail of the generated ids to the stop sequence
+            if torch.all(row[-n:] == torch.tensor(stop_ids, device=device, dtype=row.dtype)):
+                finished[b] = True
+                break
+
+    return finished
+
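A small worked example of the tail comparison (token ids are illustrative): row 0 ends in the stop sequence [9, 4], row 1 matches nothing.

import torch

ids = torch.tensor([[5, 1, 9, 4],
                    [5, 1, 2, 8]])
stop_strings = [[9, 4], [3]]  # List[List[int]] of stop token sequences
print(check_stopping_strings(ids, stop_strings))  # tensor([ True, False])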
+def check_stopping_criteria(input_ids: torch.Tensor, max_length: int, eos_token, stop_strings: list = None):
+    device = input_ids.device

    if not isinstance(eos_token, torch.Tensor):
-        eos_token = torch.tensor(eos_token, device=input_ids.device)
+        eos_token = torch.tensor(eos_token, device=device)

    max_len_done = input_ids.shape[1] >= max_length

    eos_done = torch.isin(input_ids[:, -1], eos_token)

-    # finished either by lenght or eos
-    finished_mask = max_len_done | eos_done
+    if stop_strings is not None:
+        stop_done = check_stopping_strings(input_ids, stop_strings)
+    else:
+        stop_done = torch.zeros(input_ids.size(0), dtype=torch.bool, device=device)
+
+    # finished either by length or eos or stop strings
+    finished_mask = max_len_done | eos_done | stop_done

    unfinished_mask = ~finished_mask
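Putting the three signals together (values illustrative; the diff ends before the function's return, so only the masks are traced here): neither row has reached max_length, neither ends in the EOS id 2, and only row 0 matches a stop string.

import torch

ids = torch.tensor([[5, 1, 9, 4],
                    [5, 1, 2, 8]])
check_stopping_criteria(ids, max_length=16, eos_token=[2], stop_strings=[[9, 4]])
# internally: max_len_done == False, eos_done == tensor([False, False]),
# stop_done == tensor([ True, False]), so finished_mask == tensor([ True, False])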