Commit 536e6a0

Bump transformers minor version; fix TypicalLogitsWarper
1 parent a39b8a8 commit 536e6a0

File tree

3 files changed: +9 −8 lines changed


server/poetry.lock

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default.

server/pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@ bitsandbytes = { version = "^0.41.1", optional = true }
 scipy = { version = "^1.11.2", optional = true }
 safetensors = "^0.4.0"
 sentencepiece = "^0.1.99"
-transformers = "4.34.0"
+transformers = "4.34.1"
 optimum = { version = "1.13.2", extras = ["onnxruntime-gpu"], optional = true }
 onnxruntime = { version = "1.16.0", optional = true }
 onnxruntime-gpu = { version = "1.16.0", optional = true }

server/text_generation_server/utils/logits_process.py

Lines changed: 4 additions & 3 deletions
@@ -419,7 +419,8 @@ def filter(self, indices):
         return None


-# This is a fixed version of the class in transformers. Can be moved once we contribute back the fix and upgrade.
+# This is a fixed version of the class in transformers, see https://github.com/huggingface/transformers/pull/26579.
+# Can be removed after upgrading to transformers v4.35+
 class TypicalLogitsWarper(LogitsWarper):
     r"""
     [`LogitsWarper`] that performs typical decoding. See [Typical Decoding for Natural Language
@@ -456,8 +457,8 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)

         # Remove tokens with cumulative mass above the threshold
-        last_ind = (cumulative_probs < self.mass).sum(dim=1)
-        last_ind.clamp_(0, sorted_scores.shape[-1] - 1)
+        last_ind = (cumulative_probs < self.mass).sum(dim=1) - 1
+        last_ind.clamp_(min=0)
         sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1))
         # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
         sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
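
For context, a minimal sketch (not part of the commit) of how the old and new last_ind computations differ on a toy batch. The mass value and cumulative_probs tensor below are made up for illustration, standing in for the typicality-sorted cumulative probabilities computed earlier in __call__:

import torch

mass = 0.9  # typicality mass threshold (self.mass in the warper)

# Toy cumulative probabilities over typicality-sorted tokens, one row per sequence.
cumulative_probs = torch.tensor([
    [0.30, 0.60, 0.95, 1.00],  # crosses the mass threshold at index 2
    [0.92, 0.97, 0.99, 1.00],  # already above the threshold at index 0
])

# Previous computation: count entries below the threshold, then clamp into range.
old_last_ind = (cumulative_probs < mass).sum(dim=1)
old_last_ind.clamp_(0, cumulative_probs.shape[-1] - 1)

# Fixed computation from this commit: step back one position, then clamp so the
# index never goes below 0 and the most typical token is always kept.
new_last_ind = (cumulative_probs < mass).sum(dim=1) - 1
new_last_ind.clamp_(min=0)

print(old_last_ind)  # tensor([2, 0])
print(new_last_ind)  # tensor([1, 0])

Both versions then use last_ind in sorted_scores.gather(1, last_ind.view(-1, 1)) to find the score cutoff; with the fix, the index stays in bounds by construction (the count is at most the vocabulary size, so count − 1 never exceeds the last position) rather than by clamping to the top of the range.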
