 import yaml
 from huggingface_hub import AsyncInferenceClient, ChatCompletionOutput
+from huggingface_hub.errors import HfHubHTTPError
 from pydantic import NonNegativeInt
 from tqdm import tqdm
 from tqdm.asyncio import tqdm as async_tqdm
@@ -60,13 +61,15 @@ class InferenceProvidersModelConfig(ModelConfig):
         provider: Name of the inference provider
         timeout: Request timeout in seconds
         proxies: Proxy configuration for requests
+        org_to_bill: Organisation to bill instead of the authenticated user, if set
         generation_parameters: Parameters for text generation
     """

     model_name: str
     provider: str
     timeout: int | None = None
     proxies: Any | None = None
+    org_to_bill: str | None = None
     parallel_calls_count: NonNegativeInt = 10

     @classmethod
@@ -78,12 +81,14 @@ def from_path(cls, path):
         provider = config.get("provider", None)
         timeout = config.get("timeout", None)
         proxies = config.get("proxies", None)
+        org_to_bill = config.get("org_to_bill", None)
         generation_parameters = GenerationParameters.from_dict(config)
         return cls(
             model=model_name,
             provider=provider,
             timeout=timeout,
             proxies=proxies,
+            org_to_bill=org_to_bill,
             generation_parameters=generation_parameters,
         )
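For reference, a minimal sketch of how the new org_to_bill key would flow through from_path, assuming the keys read by the config.get calls above sit at the top level of the YAML file (the model_name key and all values are illustrative placeholders, not taken from the PR):

import yaml

# Hypothetical config matching the keys read by from_path above;
# every value here is a placeholder.
config = yaml.safe_load("""
model_name: meta-llama/Llama-3.1-8B-Instruct
provider: together
timeout: 120
org_to_bill: my-org
""")
assert config.get("org_to_bill", None) == "my-org"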
@@ -121,19 +126,20 @@ def __init__(self, config: InferenceProvidersModelConfig) -> None:
             provider=self.provider,
             timeout=config.timeout,
             proxies=config.proxies,
+            bill_to=config.org_to_bill,
         )
-        self._tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+        try:
+            self._tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+        except HfHubHTTPError as e:
+            logger.warning(f"Could not load model's tokenizer: {e}.")
+            self._tokenizer = None

     def _encode(self, text: str) -> dict:
-        enc = self._tokenizer(text=text)
-        return enc
-
-    def tok_encode(self, text: str | list[str]):
-        if isinstance(text, list):
-            toks = [self._encode(t["content"]) for t in text]
-            toks = [tok for tok in toks if tok]
-            return toks
-        return self._encode(text)
+        if self._tokenizer:
+            enc = self._tokenizer(text=text)
+            return enc
+        logger.warning("Tokenizer is not loaded, can't encode the text, returning it as is.")
+        return text

     async def __call_api(self, prompt: List[dict], num_samples: int) -> Optional[ChatCompletionOutput]:
         """Make API call with exponential backoff retry logic.
@@ -204,9 +210,6 @@ def greedy_until(
         Returns:
             list[GenerativeResponse]: list of generated responses.
         """
-        for request in requests:
-            request.tokenized_context = self.tok_encode(request.context)
-
         dataset = GenerativeTaskDataset(requests=requests, num_dataset_splits=self.DATASET_SPLITS)
         results = []
@@ -247,7 +250,11 @@ def add_special_tokens(self) -> bool:
     @property
     def max_length(self) -> int:
         """Return the maximum sequence length of the model."""
-        return self._tokenizer.model_max_length
+        try:
+            return self._tokenizer.model_max_length
+        except AttributeError:
+            logger.warning("Tokenizer was not correctly loaded. Max model context length is assumed to be 30K tokens.")
+            return 30000

     def loglikelihood(
         self, requests: list[LoglikelihoodRequest], override_bs: Optional[int] = None
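As a usage note: the constructor change above forwards the new config field to huggingface_hub's AsyncInferenceClient as bill_to, so requests are billed to that organisation rather than the authenticated user. A standalone sketch of the resulting client construction (provider, timeout, and organisation name are placeholders):

from huggingface_hub import AsyncInferenceClient

# Mirrors the patched constructor call: requests made through this client
# are billed to the "my-org" organisation instead of the calling user.
client = AsyncInferenceClient(
    provider="together",
    timeout=120,
    bill_to="my-org",
)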