 
 import yaml
 from huggingface_hub import AsyncInferenceClient, ChatCompletionOutput
+from huggingface_hub.errors import HfHubHTTPError
 from pydantic import NonNegativeInt
 from tqdm import tqdm
 from tqdm.asyncio import tqdm as async_tqdm
@@ -60,13 +61,15 @@ class InferenceProvidersModelConfig(ModelConfig):
         provider: Name of the inference provider
         timeout: Request timeout in seconds
         proxies: Proxy configuration for requests
+        org_to_bill: Organisation to bill instead of the authenticated user
         generation_parameters: Parameters for text generation
     """
 
     model_name: str
     provider: str
     timeout: int | None = None
     proxies: Any | None = None
+    org_to_bill: str | None = None
     parallel_calls_count: NonNegativeInt = 10
 
     @classmethod
@@ -78,12 +81,14 @@ def from_path(cls, path):
         provider = config.get("provider", None)
         timeout = config.get("timeout", None)
         proxies = config.get("proxies", None)
+        org_to_bill = config.get("org_to_bill", None)
         generation_parameters = GenerationParameters.from_dict(config)
         return cls(
             model=model_name,
             provider=provider,
             timeout=timeout,
             proxies=proxies,
+            org_to_bill=org_to_bill,
             generation_parameters=generation_parameters,
         )
 
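For context, a minimal sketch of a config that exercises the new field, read the way `from_path` reads it. The flat key layout and all values are illustrative assumptions; only the key names mirror the `config.get(...)` calls above.

import yaml

# Hypothetical config file contents; keys mirror the `config.get(...)`
# calls in `from_path`, values are placeholders.
raw = """
model_name: meta-llama/Llama-3.1-8B-Instruct
provider: together
timeout: 120
org_to_bill: my-organization
"""
config = yaml.safe_load(raw)
assert config["org_to_bill"] == "my-organization"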
@@ -121,19 +126,20 @@ def __init__(self, config: InferenceProvidersModelConfig) -> None:
             provider=self.provider,
             timeout=config.timeout,
             proxies=config.proxies,
+            bill_to=config.org_to_bill,
         )
-        self._tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+        try:
+            self._tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+        except HfHubHTTPError as e:
+            logger.warning(f"Could not load model's tokenizer: {e}.")
+            self._tokenizer = None
 
     def _encode(self, text: str) -> dict:
-        enc = self._tokenizer(text=text)
-        return enc
-
-    def tok_encode(self, text: str | list[str]):
-        if isinstance(text, list):
-            toks = [self._encode(t["content"]) for t in text]
-            toks = [tok for tok in toks if tok]
-            return toks
-        return self._encode(text)
+        if self._tokenizer:
+            enc = self._tokenizer(text=text)
+            return enc
+        logger.warning("Tokenizer is not loaded, can't encode the text, returning it as is.")
+        return text
 
     async def __call_api(self, prompt: List[dict], num_samples: int) -> Optional[ChatCompletionOutput]:
         """Make API call with exponential backoff retry logic.
@@ -204,9 +210,6 @@ def greedy_until(
         Returns:
             list[GenerativeResponse]: list of generated responses.
         """
-        for request in requests:
-            request.tokenized_context = self.tok_encode(request.context)
-
         dataset = GenerativeTaskDataset(requests=requests, num_dataset_splits=self.DATASET_SPLITS)
         results = []
@@ -247,7 +250,11 @@ def add_special_tokens(self) -> bool:
     @property
     def max_length(self) -> int:
         """Return the maximum sequence length of the model."""
-        return self._tokenizer.model_max_length
+        try:
+            return self._tokenizer.model_max_length
+        except AttributeError:
+            logger.warning("Tokenizer was not correctly loaded; assuming a max model context length of 30K tokens.")
+            return 30000
 
     def loglikelihood(
         self, requests: list[LoglikelihoodRequest], override_bs: Optional[int] = None
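A self-contained sketch of the `max_length` fallback path: when the tokenizer failed to load and is None, the attribute access raises AttributeError and the assumed 30K-token default is returned. The class name here is illustrative.

class _TokenizerFallbackDemo:
    # Mimics the state after a failed tokenizer download.
    _tokenizer = None

    @property
    def max_length(self) -> int:
        try:
            return self._tokenizer.model_max_length
        except AttributeError:
            return 30000

assert _TokenizerFallbackDemo().max_length == 30000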