Commit 690ef8b

m-misiura and baberabb authored
Leverage vllm's tokenizer_info endpoint to avoid manual duplication (#3185)
* ✨ added an approach to use tokenizer_info endpoint from vllm
* 🚧 removed all auto-detection and tokenization logic from `LocalChatCompletion`
* pacify pre-commit

Signed-off-by: m-misiura <[email protected]>
Co-authored-by: Baber <[email protected]>
1 parent 655718d commit 690ef8b
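
As a rough usage sketch (not part of this commit), this is how the new `remote` backend might be requested from Python once a vLLM OpenAI-compatible server is running; the `local-completions` model name, server URL, and model identifier below are placeholders, and `tokenizer_backend` can also be left to the new auto-detection:

    # Hedged example: point lm-eval at a vLLM OpenAI-compatible server and let it
    # fetch tokenizer details remotely instead of loading a HF tokenizer locally.
    import lm_eval

    results = lm_eval.simple_evaluate(
        model="local-completions",  # registered completions API model (assumed name)
        model_args=(
            "model=meta-llama/Llama-3.1-8B-Instruct,"         # placeholder model id
            "base_url=http://localhost:8000/v1/completions,"  # placeholder server URL
            "tokenizer_backend=remote"                        # or omit: "auto" probes the server
        ),
        tasks=["hellaswag"],
    )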

File tree

5 files changed: +560 -15 lines changed


lm_eval/models/api_models.py

Lines changed: 43 additions & 2 deletions
@@ -114,7 +114,7 @@ def __init__(
         # however the requests can be sent as a string if the API doesn't support token inputs.
         # use tokenized_requests=False
         tokenizer_backend: Optional[
-            Literal["tiktoken", "huggingface", "None", "none"]
+            Literal["tiktoken", "huggingface", "remote", "None", "none"]
         ] = "huggingface",
         truncate: bool = False,
         # number of concurrent requests. More useful if not batching
@@ -132,6 +132,8 @@ def __init__(
         revision: Optional[str] = "main",
         use_fast_tokenizer: bool = True,
         verify_certificate: bool = True,
+        ca_cert_path: Optional[str] = None,
+        auth_token: Optional[str] = None,
         eos_string: str = None,
         # timeout in seconds
         timeout: int = 300,
@@ -182,6 +184,8 @@ def __init__(
         self.tokenized_requests = tokenized_requests
         self.max_retries = int(max_retries)
         self.verify_certificate = verify_certificate
+        self.ca_cert_path = ca_cert_path
+        self.auth_token = auth_token
         self._eos_string = eos_string
         self.timeout = int(timeout)
         self.max_images = int(max_images)
@@ -218,6 +222,21 @@ def __init__(
                     f"Passed `base_url={self.base_url}` but using (OpenAI) Tiktoken tokenizer backend. "
                     "Pass `tokenizer_backend=huggingface` and provide the HF tokenizer name if your model does not use Tiktoken."
                 )
+        elif self.tokenizer_backend == "remote":
+            from lm_eval.utils import RemoteTokenizer
+
+            if not self.base_url:
+                raise ValueError(
+                    "base_url is required for remote tokenizer backend"
+                )
+            self.tokenizer = RemoteTokenizer(
+                self.base_url,
+                self.timeout,
+                self.verify_certificate,
+                self.ca_cert_path,
+                self.auth_token,
+            )
+            eval_logger.info(f"Using remote tokenizer from {self.base_url}")
         else:
             import transformers

@@ -310,7 +329,7 @@ def tokenizer_name(self) -> str:

     def apply_chat_template(
         self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True
-    ) -> Union[str, JsonChatStr]:
+    ) -> Union[str, JsonChatStr, List[Dict]]:
         """Applies a chat template to a list of chat history between user and model."""
         if self.tokenizer_backend == "huggingface" and self.tokenized_requests:
             return self.tokenizer.apply_chat_template(
@@ -319,6 +338,8 @@ def apply_chat_template(
                 add_generation_prompt=add_generation_prompt,
                 continue_final_message=not add_generation_prompt,
             )
+        elif self.tokenizer_backend == "remote" and self.tokenized_requests:
+            return chat_history
         else:
             # bit of a hack. We'll load back before sending to the API
             return JsonChatStr(
@@ -337,6 +358,8 @@ def eot_token_id(self) -> Optional[int]:
             return self.tokenizer.eos_token_id
         elif self.tokenizer_backend == "tiktoken":
             return self.tokenizer.eot_token
+        elif self.tokenizer_backend == "remote":
+            return self.tokenizer.eos_token_id

     @cached_property
     def eos_string(self) -> Optional[str]:
@@ -347,6 +370,8 @@ def eos_string(self) -> Optional[str]:
             return self.tokenizer.eos_token
         elif self.tokenizer_backend == "tiktoken":
             return self.tokenizer.decode([self.tokenizer.eot_token])
+        elif self.tokenizer_backend == "remote":
+            return self.tokenizer.eos_token
         else:
             eval_logger.warning(
                 "Cannot determine EOS string to pass to stop sequence. Manually set by passing `eos_string` to model_args."
@@ -364,6 +389,8 @@ def prefix_token_id(self) -> Optional[int]:
             if self.tokenizer.bos_token_id is not None:
                 return self.tokenizer.bos_token_id
             return self.tokenizer.eos_token_id
+        elif self.tokenizer_backend == "remote":
+            return self.tokenizer.bos_token_id or self.tokenizer.eos_token_id
         else:
             return self.tokenizer.eot_token

@@ -396,7 +423,19 @@ def tok_encode(
                 encoding = encoding[-left_truncate_len:]

             return encoding
+        elif self.tokenizer_backend == "remote":
+            if isinstance(string, str):
+                encoding = self.tokenizer.encode(string)
+            else:
+                encoding = [self.tokenizer.encode(s) for s in string]

+            if left_truncate_len:
+                if isinstance(string, str):
+                    encoding = encoding[-left_truncate_len:]
+                else:
+                    encoding = [enc[-left_truncate_len:] for enc in encoding]
+
+            return encoding
         else:
             try:
                 encoding = self.tokenizer.encode(string)
@@ -409,6 +448,8 @@ def decode_batch(self, tokens: List[List[int]]) -> List[str]:
             return self.tokenizer.batch_decode(tokens)
         elif self.tokenizer_backend == "tiktoken":
             return self.tokenizer.decode_batch(tokens)
+        elif self.tokenizer_backend == "remote":
+            return self.tokenizer.batch_decode(tokens)

     def model_call(
         self,
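
The `RemoteTokenizer` class itself is imported from `lm_eval.utils` and is not part of this excerpt. The following is a hypothetical sketch of the interface the diff above relies on (`encode`, `batch_decode`, `eos_token`, `eos_token_id`, `bos_token_id`), assuming it wraps vLLM's tokenizer-related HTTP endpoints; the endpoint paths, JSON field names, and URL handling are assumptions, not the actual implementation:

    # Hypothetical sketch of what api_models.py expects from lm_eval.utils.RemoteTokenizer.
    # Endpoint paths and JSON field names are assumptions based on vLLM's HTTP API.
    import requests


    class RemoteTokenizerSketch:
        def __init__(self, base_url, timeout=300, verify_certificate=True,
                     ca_cert_path=None, auth_token=None):
            # Assumes base_url points at the server root; the real class may need
            # to strip a trailing /v1/completions path first.
            self.base_url = base_url.rstrip("/")
            self.timeout = timeout
            # requests accepts either a bool or a CA bundle path for `verify`.
            self.verify = ca_cert_path if ca_cert_path else verify_certificate
            self.headers = {"Authorization": f"Bearer {auth_token}"} if auth_token else {}
            # Fetch special-token metadata once from the (assumed) tokenizer_info endpoint.
            info = requests.get(
                f"{self.base_url}/tokenizer_info",
                headers=self.headers, timeout=self.timeout, verify=self.verify,
            ).json()
            self.eos_token = info.get("eos_token")
            self.eos_token_id = info.get("eos_token_id")
            self.bos_token_id = info.get("bos_token_id")

        def encode(self, text):
            # Assumed to proxy the server's /tokenize endpoint.
            resp = requests.post(
                f"{self.base_url}/tokenize", json={"prompt": text},
                headers=self.headers, timeout=self.timeout, verify=self.verify,
            )
            return resp.json()["tokens"]

        def batch_decode(self, batch):
            # Assumed to proxy the server's /detokenize endpoint, one call per sequence.
            decoded = []
            for tokens in batch:
                resp = requests.post(
                    f"{self.base_url}/detokenize", json={"tokens": tokens},
                    headers=self.headers, timeout=self.timeout, verify=self.verify,
                )
                decoded.append(resp.json()["prompt"])
            return decoded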

lm_eval/models/openai_completions.py

Lines changed: 59 additions & 13 deletions
@@ -16,12 +16,46 @@
 class LocalCompletionsAPI(TemplateAPI):
     def __init__(
         self,
-        base_url: str = None,
-        tokenizer_backend: str = "huggingface",
+        base_url=None,
+        tokenizer_backend="auto",
+        verify_certificate=True,
+        ca_cert_path=None,
+        auth_token=None,
         **kwargs,
     ):
+        # Auto-detect tokenizer backend
+        if tokenizer_backend == "auto":
+            if base_url:
+                from lm_eval.utils import check_remote_tokenizer_support
+
+                if check_remote_tokenizer_support(
+                    base_url,
+                    verify_certificate=verify_certificate,
+                    ca_cert_path=ca_cert_path,
+                    auth_token=auth_token,
+                ):
+                    eval_logger.info(
+                        "Auto-detected remote tokenizer support. Using remote tokenizer backend."
+                    )
+                    tokenizer_backend = "remote"
+                else:
+                    eval_logger.info(
+                        "Remote tokenizer not supported. Using huggingface tokenizer backend."
+                    )
+                    tokenizer_backend = "huggingface"
+            else:
+                eval_logger.warning(
+                    "No base_url provided. Using huggingface tokenizer backend."
+                )
+                tokenizer_backend = "huggingface"
+
         super().__init__(
-            base_url=base_url, tokenizer_backend=tokenizer_backend, **kwargs
+            base_url=base_url,
+            tokenizer_backend=tokenizer_backend,
+            verify_certificate=verify_certificate,
+            ca_cert_path=ca_cert_path,
+            auth_token=auth_token,
+            **kwargs,
         )

     def _create_payload(
@@ -106,20 +140,28 @@ def api_key(self):

 @register_model("local-chat-completions")
 class LocalChatCompletion(LocalCompletionsAPI):
+    """
+    Minimal chat-completions wrapper.
+    - Only accepts messages as list[dict].
+    - No tokenization or template logic.
+    - Use with --apply_chat_template or ensure upstream formats messages correctly.
+    """
+
     def __init__(
         self,
-        base_url: str = None,
-        tokenizer_backend: str = None,
-        tokenized_requests: bool = False,
+        base_url=None,
+        verify_certificate=True,
+        ca_cert_path=None,
+        auth_token=None,
         **kwargs,
     ):
-        eval_logger.warning(
-            "chat-completions endpoint requires the `--apply_chat_template` flag."
-        )
         super().__init__(
             base_url=base_url,
-            tokenizer_backend=tokenizer_backend,
-            tokenized_requests=tokenized_requests,
+            tokenizer_backend=None,
+            tokenized_requests=None,
+            verify_certificate=verify_certificate,
+            ca_cert_path=ca_cert_path,
+            auth_token=auth_token,
             **kwargs,
         )
         if self._batch_size > 1:
@@ -137,9 +179,13 @@ def _create_payload(
         eos=None,
         **kwargs,
     ) -> dict:
-        assert type(messages) is not str, (
-            "chat-completions require the --apply_chat_template flag."
+        assert isinstance(messages, list) and all(
+            isinstance(m, dict) for m in messages
+        ), (
+            "LocalChatCompletion expects messages as list[dict]. "
+            "If you see this error, ensure --apply_chat_template is set or upstream code formats messages correctly."
         )
+        gen_kwargs = gen_kwargs or {}
         gen_kwargs.pop("do_sample", False)
         if "max_tokens" in gen_kwargs:
             max_tokens = gen_kwargs.pop("max_tokens")
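
Like `RemoteTokenizer`, the `check_remote_tokenizer_support` helper imported above is defined elsewhere in the commit. A minimal sketch of the kind of probe it presumably performs is below; the probed path and the "any successful response means supported" rule are assumptions drawn only from how the result is used in this file:

    # Hypothetical sketch of lm_eval.utils.check_remote_tokenizer_support.
    import requests


    def check_remote_tokenizer_support_sketch(base_url, verify_certificate=True,
                                              ca_cert_path=None, auth_token=None,
                                              timeout=10):
        headers = {"Authorization": f"Bearer {auth_token}"} if auth_token else {}
        verify = ca_cert_path if ca_cert_path else verify_certificate
        try:
            resp = requests.get(
                f"{base_url.rstrip('/')}/tokenizer_info",  # assumed probe target
                headers=headers, timeout=timeout, verify=verify,
            )
            return resp.ok
        except requests.RequestException:
            # Any connection/TLS error means falling back to the huggingface backend.
            return False

The design intent visible in the diff is that `tokenizer_backend="auto"` only opts into the remote backend when this probe succeeds, so existing configurations pointing at servers without tokenizer endpoints keep the previous huggingface behaviour.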
