Commit 642041d

Some minor token processing logic cleanup
From recent code observations
1 parent 1feed99 commit 642041d

3 files changed (+39, −36 lines)

router/src/server.rs

Lines changed: 1 addition & 1 deletion

@@ -207,7 +207,7 @@ impl<B: BatchType> BatchConfigValidator<B> {
         );
         if max_prefill_weight < single_request_prefill_weight {
             panic!(
-                "max_prefill_weight ({}) not large enough for max_sequence_length ({}",
+                "max_prefill_weight ({}) not large enough for max_sequence_length ({})",
                 max_prefill_weight, max_sequence_length
             )
         }

server/text_generation_server/utils/logits_process.py

Lines changed: 21 additions & 21 deletions

@@ -121,10 +121,10 @@ def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
 
     def filter(self, indices):
         self.penalty = [self.penalty[i] for i in indices]
-        if any([x != 1.0 for x in self.penalty]):
-            self.penalty_tensor = self.penalty_tensor[indices]
-            return self
-        return None
+        if all(x == 1.0 for x in self.penalty):
+            return None
+        self.penalty_tensor = self.penalty_tensor[indices]
+        return self
 
 
 class HeterogeneousTemperatureLogitsWarper:
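
The three `filter` rewrites in this file share one shape: instead of testing whether any request still needs the warper and nesting the useful work inside, they return `None` early when every remaining value is the no-op (1.0) and fall through to the tensor reindexing otherwise. A minimal sketch of the pattern, assuming a processor that keeps a per-request list plus a matching tensor (the `__call__` logic that actually applies the penalty is omitted):

```python
import torch

class HeterogeneousRepetitionPenaltyLogitsProcessor:
    """Sketch only: the real class also implements __call__."""

    def __init__(self, penalty: list, dtype: torch.dtype, device: torch.device):
        self.penalty = penalty
        # One penalty per request in the batch, kept as a column for broadcasting
        self.penalty_tensor = torch.tensor(
            penalty, dtype=dtype, device=device
        ).unsqueeze(1)

    def filter(self, indices: list):
        # Keep only the requests still in the batch
        self.penalty = [self.penalty[i] for i in indices]
        # Early out: if every remaining penalty is the no-op value,
        # returning None signals the caller to drop this processor
        if all(x == 1.0 for x in self.penalty):
            return None
        self.penalty_tensor = self.penalty_tensor[indices]
        return self
```

Returning `None` presumably lets the batching code prune warpers that have become no-ops as requests complete, so later decode steps can skip their tensor work entirely.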

@@ -152,10 +152,10 @@ def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
 
     def filter(self, indices):
         self.temperature = [self.temperature[i] for i in indices]
-        if any([x != 1.0 for x in self.temperature]):
-            self.temperature_tensor = self.temperature_tensor[indices]
-            return self
-        return None
+        if all(x == 1.0 for x in self.temperature):
+            return None
+        self.temperature_tensor = self.temperature_tensor[indices]
+        return self
 
 
 class HeterogeneousTopPLogitsWarper(LogitsWarper):

@@ -211,10 +211,10 @@ def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
 
     def filter(self, indices):
         self.top_p = [self.top_p[i] for i in indices]
-        if any([x < 1.0 for x in self.top_p]):
-            self.top_p_opposite = self.top_p_opposite[indices]
-            return self
-        return None
+        if all(x == 1.0 for x in self.top_p):
+            return None
+        self.top_p_opposite = self.top_p_opposite[indices]
+        return self
 
 
 class HeterogeneousTopKLogitsWarper(LogitsWarper):

@@ -270,7 +270,7 @@ def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
             top_k = self.top_k_tensor
 
         # Get the kth score for each member of the batch
-        kth_scores = torch.gather(torch.topk(scores, max_top_k)[0], 1, top_k)
+        kth_scores = torch.gather(torch.topk(scores, max_top_k).values, 1, top_k)
 
         # Mask member of kth_scores that do not want to use top_k warping
         if self.top_k_disabled_mask is not None:
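
The `[0]` → `.values` change is cosmetic: `torch.topk` returns a named tuple, so the field access is equivalent to positional indexing but self-documenting. A quick illustration:

```python
import torch

scores = torch.tensor([[0.1, 0.4, 0.3, 0.2]])

top = torch.topk(scores, k=2)
# topk returns a named tuple; .values/.indices equal [0]/[1]
assert torch.equal(top.values, top[0])   # tensor([[0.4, 0.3]])
assert torch.equal(top.indices, top[1])  # tensor([[1, 2]])
```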

@@ -376,16 +376,16 @@ def filter(self, indices):
         self.mass = [self.mass[i] for i in indices]
         disabled = [x == 1.0 for x in self.mass]
 
-        if not all(disabled):
-            self.mass_tensor = self.mass_tensor[indices]
+        if all(disabled):
+            return None
 
-            if self.disabled_mask is not None:
-                self.disabled_mask = (
-                    self.disabled_mask[indices] if any(disabled) else None
-                )
+        self.mass_tensor = self.mass_tensor[indices]
 
-            return self
-        return None
+        if self.disabled_mask is not None:
+            self.disabled_mask = (
+                self.disabled_mask[indices] if any(disabled) else None
+            )
+        return self
 
 
 # NB: This class is not currently used.
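
The typical-mass `filter` has one extra wrinkle: `mass == 1.0` marks a request as disabled, and `disabled_mask` is only worth keeping while the surviving batch is mixed. A hypothetical trace with invented values:

```python
# Hypothetical batch: masses for three requests; 1.0 means disabled
mass = [1.0, 0.9, 1.0]
indices = [0, 1]  # request 2 finished and was filtered out

mass = [mass[i] for i in indices]    # [1.0, 0.9]
disabled = [x == 1.0 for x in mass]  # [True, False]

assert not all(disabled)  # mixed batch: keep the warper...
assert any(disabled)      # ...and keep the disabled mask as well
```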

server/text_generation_server/utils/tokens.py

Lines changed: 17 additions & 14 deletions

@@ -202,12 +202,13 @@ def __init__(
             warpers.append(HeterogeneousTopKLogitsWarper(top_k, device))
 
             if any(x < 1.0 for x in top_p):
+                #assert all(x != 0 for x in top_p)
                 warpers.append(HeterogeneousTopPLogitsWarper(top_p, dtype, device))
-            # We specifically exclude degenerate case 0, we devolves into greedy decoding,
-            # to align with the nonvectorized typical logit warping behavior.
-            if any(0.0 < x < 1.0 for x in typical_p):
-                corrected_probs = [p if p != 0 else 1 for p in typical_p]
-                warpers.append(HeterogeneousTypicalLogitsWarper(corrected_probs, dtype, device))
+
+            if any(x < 1.0 for x in typical_p):
+                #assert all(x != 0 for x in typical_p)
+                warpers.append(HeterogeneousTypicalLogitsWarper(typical_p, dtype, device))
+
             self.choice = HeterogeneousSampling(do_sample, seeds, device)
         else:
             self.choice = Greedy()

@@ -279,8 +280,10 @@ def from_pb(
             temperature=[pb_.temperature for pb_ in pb],
             repetition_penalty=[pb_.repetition_penalty if pb_.HasField('repetition_penalty') else 1.0 for pb_ in pb],
             top_k=[pb_.top_k for pb_ in pb],
-            top_p=[pb_.top_p for pb_ in pb],
-            typical_p=[pb_.typical_p for pb_ in pb],
+            # Ensure that default (zero) values for top_p and typical_p are converted to 1.0
+            # (which corresponds to disabled in both cases)
+            top_p=[pb_.top_p if pb_.top_p > 0 else 1.0 for pb_ in pb],
+            typical_p=[pb_.typical_p if pb_.typical_p > 0 else 1.0 for pb_ in pb],
             length_penalty=[
                 (pb_.length_penalty.start_index, pb_.length_penalty.decay_factor)
                 if pb_.HasField('length_penalty') else None for pb_ in pb
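
The `from_pb` change deals with proto3 semantics: unset scalar fields arrive as 0, and 0 is not a meaningful `top_p` or `typical_p` (it would degenerate into greedy decoding), so the protobuf boundary maps it to 1.0, which both warpers treat as disabled. A self-contained sketch using a stand-in for the protobuf message (`FakeParams` is invented for illustration; real code reads `generate_pb2` requests):

```python
class FakeParams:
    """Stand-in for a generate_pb2 sampling-parameters message."""
    def __init__(self, top_p=0.0, typical_p=0.0):
        self.top_p = top_p        # proto3 scalars default to 0
        self.typical_p = typical_p

pb = [FakeParams(), FakeParams(top_p=0.9)]

# Map the proto3 default (0) to 1.0 == disabled, as in from_pb above
top_p = [pb_.top_p if pb_.top_p > 0 else 1.0 for pb_ in pb]
typical_p = [pb_.typical_p if pb_.typical_p > 0 else 1.0 for pb_ in pb]

assert top_p == [1.0, 0.9] and typical_p == [1.0, 1.0]
```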

@@ -310,7 +313,7 @@ def filter(self, indices):
         self.current_tokens = [self.current_tokens[i] for i in indices]
         self.min_new_tokens = [self.min_new_tokens[i] for i in indices]
         self.length_penalty = [self.length_penalty[i] for i in indices]
-        self.return_logprobs = [self.return_logprobs[i] for i in indices]
+        self.return_logprobs = [self.return_logprobs[i] for i in indices]
 
         if any(self.do_sample):
             self.choice.filter(indices)

@@ -370,12 +373,13 @@ def filter(self, indices):
         self.samplings = [self.samplings[i] for i in indices]
         return self
 
+
 # Extract requested token information from model output
 def get_token_info(
     request: generate_pb2.Request,
     scores: torch.Tensor,  # Assumes shape is [1, vocab_size]
     next_token: torch.Tensor,
-    logprobs: Optional[torch.Tensor], # Assumes shape matches logits
+    logprobs: Optional[torch.Tensor],  # Assumes shape matches logits
 ) -> TokenInfo:
     next_token = next_token.item()
     token_info = TokenInfo(request_id=request.id, token_id=next_token)

@@ -391,9 +395,8 @@ def get_token_info(
         # Ensure top_n doesn't exceed vocab size
         top_n = min(return_top_n, flat_scores.size(-1))
         # Get nth highest value, ensure it's not -inf (for example if top_n > top_k)
-        nth_highest = torch.topk(flat_scores, top_n)[0][-1]
-        if nth_highest == -float('inf'):
-            nth_highest = torch.finfo(flat_scores.dtype).min
+        nth_highest = flat_scores.topk(top_n).values[-1]
+        torch.nan_to_num_(nth_highest, neginf=torch.finfo(torch.float).min)
        # Get indices (token ids) of all scores >= nth highest value,
        # cap length at 4 * top_n as a precaution
        top_n_indices = (flat_scores >= nth_highest).nonzero().squeeze(-1)[:(top_n * 4)]
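
The replacement collapses the explicit `-inf` check into an in-place `torch.nan_to_num_`, whose `neginf=` argument rewrites only negative-infinity entries. This matters because once top-k filtering has already run, the nth-highest score can legitimately be `-inf`, and `flat_scores >= -inf` would then select the entire vocabulary. A small demonstration:

```python
import torch

flat_scores = torch.tensor([2.0, 1.0, -float("inf"), -float("inf")])

# With top-k filtering already applied, the 3rd-highest score is -inf
nth_highest = flat_scores.topk(3).values[-1]
assert nth_highest == -float("inf")

# In-place fix: only -inf entries are rewritten; finite values untouched
torch.nan_to_num_(nth_highest, neginf=torch.finfo(torch.float).min)
assert nth_highest == torch.finfo(torch.float).min
```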

@@ -407,7 +410,7 @@ def get_token_info(
     # Token ranks if requested
     if request.details.ranks:
         #TODO if we're also returning top_n perhaps search those first
-        token_info.rank = (scores > scores[0][next_token]).sum() + 1
+        token_info.rank = (scores > scores[0, next_token]).sum() + 1
 
     return token_info
 
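
`scores[0, next_token]` indexes the 2-D tensor in one step instead of going through the intermediate row view that `scores[0][next_token]` creates; the result is identical. The rank itself is one plus the number of strictly higher scores, so tied tokens share the better rank. A toy check with invented values:

```python
import torch

scores = torch.tensor([[1.5, 3.0, 2.0, 3.0]])  # shape [1, vocab_size]
next_token = 2                                  # chosen token scores 2.0

# 1-based rank: both 3.0 entries beat it, so it ranks third
rank = (scores > scores[0, next_token]).sum() + 1
assert rank.item() == 3
```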


@@ -447,7 +450,7 @@ def get_input_tokens_info(request, input_token_ids, all_input_logits) -> InputTo
     # Ensure top_n doesn't exceed vocab size
     top_n = min(top_n, all_input_logits.size(-1))
     # Get the nth highest value for each input token's set of logits
-    nth_highest_values = torch.topk(all_input_logits, top_n)[0][..., -1, None]
+    nth_highest_values = torch.topk(all_input_logits, top_n).values[..., -1, None]
     # Construct bool tensor marking all scores >= nth highest value for each token
     diff = (all_input_logits >= nth_highest_values)
     # Gather set of marked indices for each token (correspond to top token ids)
