
Commit 7a3301f

fix: add_special_tokens in tokenize (#144)
#66 made add_special_tokens default to true, but that behaviour was not replicated in /tokenize, so the two code paths reported different token counts whenever ADD_SPECIAL_TOKENS is false. This PR fixes that by passing the flag through in /tokenize and adds a test for the tokenize method. I can follow this up with another test that compares the token counts between the methods if required, but otherwise this closes #141.
1 parent 896db8b commit 7a3301f
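
For context, a minimal sketch (not part of this commit) of why the flag changes the reported count: add_special_tokens controls whether markers such as [CLS] and [SEP] are included in the encoding, so the same text yields different token counts. The model name below is an arbitrary assumption, chosen only because its tokenizer adds special tokens.

from transformers import AutoTokenizer

# Any tokenizer that adds special tokens will do; bert-base-uncased is an
# arbitrary choice for illustration, not the adapter's model.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

text = "how far is Paris from New York?"
with_special = tokenizer.encode_plus(text, add_special_tokens=True)
without_special = tokenizer.encode_plus(text, add_special_tokens=False)

# The counts differ by the number of special tokens ([CLS] and [SEP] here),
# which is exactly the mismatch /tokenize showed before this fix.
print(len(with_special["input_ids"]), len(without_special["input_ids"]))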

File tree

3 files changed: +38 −1 lines changed


src/vllm_tgis_adapter/grpc/grpc_server.py

Lines changed: 3 additions & 1 deletion
@@ -855,7 +855,9 @@ async def Tokenize(
         # other threads
         for req in request.requests:
             batch_encoding = tokenizer.encode_plus(
-                text=req.text, return_offsets_mapping=request.return_offsets
+                text=req.text,
+                return_offsets_mapping=request.return_offsets,
+                add_special_tokens=ADD_SPECIAL_TOKENS,
             )

             # Tokenize the input text
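
ADD_SPECIAL_TOKENS above is a module-level flag in the adapter. A hypothetical sketch of how such a flag is commonly derived from an environment variable; the adapter's actual variable name, parsing, and default may differ:

import os

# Assumed: an env var toggles special-token handling, defaulting to true.
ADD_SPECIAL_TOKENS = os.environ.get("ADD_SPECIAL_TOKENS", "true").lower() in (
    "1",
    "true",
)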

tests/test_grpc_server.py

Lines changed: 8 additions & 0 deletions
@@ -25,6 +25,14 @@ def test_generation_request(grpc_client):
     assert response.stop_reason is not None


+def test_tokenize_request(grpc_client):
+    response_tokenize = grpc_client.make_request_tokenize(
+        text="Please answer the following question.\nhow far is Paris from New York?",
+    )
+
+    assert response_tokenize.token_count
+
+
 def test_generation_request_stream(grpc_client):
     streaming_response = grpc_client.make_request_stream(
         "The answer to life the universe and everything is ",

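The commit message mentions a possible follow-up test comparing token counts between methods. A hedged sketch of one simple addition in the same style, using only the fixture and fields visible in the diff above (the test name is an assumption):

def test_tokenize_request_is_deterministic(grpc_client):
    # Tokenizing the same prompt twice should report the same, non-zero count.
    prompt = "how far is Paris from New York?"
    first = grpc_client.make_request_tokenize(text=prompt)
    second = grpc_client.make_request_tokenize(text=prompt)
    assert first.token_count == second.token_count
    assert first.token_count > 0
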
tests/utils.py

Lines changed: 27 additions & 0 deletions
@@ -11,11 +11,13 @@

 from vllm_tgis_adapter.grpc.pb.generation_pb2 import (
     BatchedGenerationRequest,
+    BatchedTokenizeRequest,
     GenerationRequest,
     ModelInfoRequest,
     Parameters,
     SingleGenerationRequest,
     StoppingCriteria,
+    TokenizeRequest,
 )
 from vllm_tgis_adapter.grpc.pb.generation_pb2_grpc import GenerationServiceStub

@@ -25,6 +27,7 @@
     from vllm_tgis_adapter.grpc.pb.generation_pb2 import (
         GenerationResponse,
         ModelInfoResponse,
+        TokenizeResponse,
     )

 _T = TypeVar("_T")
@@ -173,6 +176,30 @@ def make_request_stream(
         except grpc._channel._MultiThreadedRendezvous as exc:  # noqa: SLF001
             raise RuntimeError(exc.details()) from exc

+    def make_request_tokenize(
+        self,
+        text: str | list[str],
+        model_id: str | None = None,
+        adapter_id: str | None = None,
+    ) -> TokenizeResponse | Sequence[TokenizeResponse]:
+        if single_request := isinstance(text, str):
+            text = [text]
+
+        request = BatchedTokenizeRequest(
+            model_id=model_id,
+            requests=[TokenizeRequest(text=piece) for piece in text],
+            adapter_id=adapter_id,
+        )
+
+        response = self.generation_service_stub.Tokenize(
+            request=request,
+        )
+
+        if single_request:
+            return response.responses[0]
+
+        return response.responses
+
     def __enter__(self):  # noqa: D105
         return self