Skip to content

Commit a8926f6

Browse files
committed
Revert TypicalLogitsWarper change for now
1 parent 536e6a0 commit a8926f6

File tree

3 files changed

+5
-5
lines changed

3 files changed

+5
-5
lines changed

integration_tests/Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
 gen-client:
 	# Compile protos
-	pip install grpcio-tools==1.58.0 mypy-protobuf==3.4.0 'types-protobuf>=3.20.4' --no-cache-dir
+	pip install grpcio-tools==1.59.0 mypy-protobuf==3.4.0 'types-protobuf>=3.20.4' --no-cache-dir
 	mkdir text_generation_tests/pb || true
 	python -m grpc_tools.protoc -I../proto --python_out=text_generation_tests/pb \
 		--grpc_python_out=text_generation_tests/pb --mypy_out=text_generation_tests/pb ../proto/generation.proto

server/Makefile

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ include Makefile-flash-att-v2
 
 gen-server:
 	# Compile protos
-	pip install grpcio-tools==1.58.0 mypy-protobuf==3.4.0 'types-protobuf>=3.20.4' --no-cache-dir
+	pip install grpcio-tools==1.59.0 mypy-protobuf==3.4.0 'types-protobuf>=3.20.4' --no-cache-dir
 	mkdir text_generation_server/pb || true
 	python -m grpc_tools.protoc -I../proto --python_out=text_generation_server/pb \
 		--grpc_python_out=text_generation_server/pb --mypy_out=text_generation_server/pb ../proto/generate.proto
@@ -15,7 +15,7 @@ TORCH_VERSION := 2.0.0+cu118
 
 install-torch:
	# Install specific version of torch
-	pip install ninja==1.11.1 torch==$(TORCH_VERSION) --extra-index-url $(TORCH_URL) --no-cache-dir
+	pip install ninja==1.11.1.1 torch==$(TORCH_VERSION) --extra-index-url $(TORCH_URL) --no-cache-dir
 
 install-deepspeed:
 	# Install specific version of deepspeed

server/text_generation_server/utils/logits_process.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -457,8 +457,8 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
         cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
 
         # Remove tokens with cumulative mass above the threshold
-        last_ind = (cumulative_probs < self.mass).sum(dim=1) - 1
-        last_ind.clamp_(min=0)
+        last_ind = (cumulative_probs < self.mass).sum(dim=1)
+        last_ind.clamp_(max=sorted_scores.shape[-1] - 1)
         sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1))
         # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
         sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0

0 commit comments

Comments
 (0)