Commit 039debc

[FIX] Inference providers (huggingface#701)
* added option to bill to org in inference providers
* removed tokenizer logging

1 parent 40626e7 commit 039debc

File tree

2 files changed: +22 -15 lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -56,7 +56,7 @@ dependencies = [
     # Base dependencies
     "transformers>=4.38.0",
     "accelerate",
-    "huggingface_hub[hf_xet]",
+    "huggingface_hub[hf_xet]>=0.30.2",
     "torch>=2.0,<3.0",
     "GitPython>=3.1.41", # for logging
     "datasets>=3.5.0",

src/lighteval/models/endpoints/inference_providers_model.py

Lines changed: 21 additions & 14 deletions
@@ -26,6 +26,7 @@

 import yaml
 from huggingface_hub import AsyncInferenceClient, ChatCompletionOutput
+from huggingface_hub.errors import HfHubHTTPError
 from pydantic import NonNegativeInt
 from tqdm import tqdm
 from tqdm.asyncio import tqdm as async_tqdm
@@ -60,13 +61,15 @@ class InferenceProvidersModelConfig(ModelConfig):
         provider: Name of the inference provider
         timeout: Request timeout in seconds
         proxies: Proxy configuration for requests
+        org_to_bill: Organisation to bill if not the user
         generation_parameters: Parameters for text generation
     """

     model_name: str
     provider: str
     timeout: int | None = None
     proxies: Any | None = None
+    org_to_bill: str | None = None
     parallel_calls_count: NonNegativeInt = 10

     @classmethod
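The new org_to_bill field is an ordinary optional attribute on the config, so it can also be set when building the config directly in Python. A minimal sketch, assuming direct construction; the model, provider and organisation names are placeholders, and how the config is then handed to lighteval's pipeline is outside this diff:

from lighteval.models.endpoints.inference_providers_model import InferenceProvidersModelConfig

# Hedged sketch: placeholder values, not taken from this commit.
config = InferenceProvidersModelConfig(
    model_name="deepseek-ai/DeepSeek-R1",  # placeholder model id
    provider="together",                   # placeholder provider name
    org_to_bill="my-org",                  # new: bill this organisation instead of the user
)

Leaving org_to_bill as None (the default) keeps billing on the authenticated user's own account.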
@@ -78,12 +81,14 @@ def from_path(cls, path):
         provider = config.get("provider", None)
         timeout = config.get("timeout", None)
         proxies = config.get("proxies", None)
+        org_to_bill = config.get("org_to_bill", None)
         generation_parameters = GenerationParameters.from_dict(config)
         return cls(
             model=model_name,
             provider=provider,
             timeout=timeout,
             proxies=proxies,
+            org_to_bill=org_to_bill,
             generation_parameters=generation_parameters,
         )

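For a file-based setup, from_path looks the same keys up in the parsed YAML, so opting into organisation billing is one extra org_to_bill entry. A minimal parsing sketch, assuming the keys sit at the level of the mapping that from_path queries (the earlier part of from_path, including how the file is opened and whether the keys are nested under a section, is not shown in this hunk; all values are placeholders):

import yaml

raw_config = """
model_name: deepseek-ai/DeepSeek-R1   # placeholder model id
provider: together                    # placeholder provider name
timeout: 120
org_to_bill: my-org                   # new key: organisation to bill if not the user
"""
config = yaml.safe_load(raw_config)

# Same lookups as in from_path above; a missing key falls back to None,
# which keeps billing on the user's own account.
provider = config.get("provider", None)
timeout = config.get("timeout", None)
proxies = config.get("proxies", None)
org_to_bill = config.get("org_to_bill", None)
print(provider, timeout, proxies, org_to_bill)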
@@ -121,19 +126,20 @@ def __init__(self, config: InferenceProvidersModelConfig) -> None:
             provider=self.provider,
             timeout=config.timeout,
             proxies=config.proxies,
+            bill_to=config.org_to_bill,
         )
-        self._tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+        try:
+            self._tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+        except HfHubHTTPError as e:
+            logger.warning(f"Could not load model's tokenizer: {e}.")
+            self._tokenizer = None

     def _encode(self, text: str) -> dict:
-        enc = self._tokenizer(text=text)
-        return enc
-
-    def tok_encode(self, text: str | list[str]):
-        if isinstance(text, list):
-            toks = [self._encode(t["content"]) for t in text]
-            toks = [tok for tok in toks if tok]
-            return toks
-        return self._encode(text)
+        if self._tokenizer:
+            enc = self._tokenizer(text=text)
+            return enc
+        logger.warning("Tokenizer is not loaded, can't encode the text, returning it as such.")
+        return text

     async def __call_api(self, prompt: List[dict], num_samples: int) -> Optional[ChatCompletionOutput]:
         """Make API call with exponential backoff retry logic.
@@ -204,9 +210,6 @@ def greedy_until(
         Returns:
             list[GenerativeResponse]: list of generated responses.
         """
-        for request in requests:
-            request.tokenized_context = self.tok_encode(request.context)
-
         dataset = GenerativeTaskDataset(requests=requests, num_dataset_splits=self.DATASET_SPLITS)
         results = []

@@ -247,7 +250,11 @@ def add_special_tokens(self) -> bool:
     @property
     def max_length(self) -> int:
         """Return the maximum sequence length of the model."""
-        return self._tokenizer.model_max_length
+        try:
+            return self._tokenizer.model_max_length
+        except AttributeError:
+            logger.warning("Tokenizer was not correctly loaded. Max model context length is assumed to be 30K tokens")
+            return 30000

     def loglikelihood(
         self, requests: list[LoglikelihoodRequest], override_bs: Optional[int] = None
