Commit 8b47ec1

adds openai models (#359)

1 parent d84c378

7 files changed: +293, -4 lines changed
pyproject.toml (1 addition, 0 deletions)

@@ -95,6 +95,7 @@ dev = ["lighteval[accelerate,quality,tests,multilingual]"]
 extended_tasks = [
     "langdetect", # ifeval
     "openai", # llm as a judge using openai models
+    "tiktoken"
 ]
 s3 = ["s3fs"]
 multilingual = [
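
A quick, hedged illustration (not part of the commit) of how the new optional dependency is expected to be consumed: the OpenAI stack is only imported lazily, behind the availability check that the loader changes below rely on.

from lighteval.utils.imports import is_openai_available

if is_openai_available():
    import tiktoken  # installed via the "extended_tasks" extra together with "openai"
    from openai import OpenAI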

src/lighteval/models/model_config.py (9 additions, 0 deletions)

@@ -232,6 +232,11 @@ class VLLMModelConfig:
     temperature: float = 0.6 # will be used for multi sampling tasks, for tasks requiring no sampling, this will be ignored and set to 0.


+@dataclass
+class OpenAIModelConfig:
+    model: str
+
+
 @dataclass
 class TGIModelConfig:
     inference_server_address: str

@@ -308,6 +313,7 @@ def create_model_config( # noqa: C901
     InferenceEndpointModelConfig,
     DummyModelConfig,
     VLLMModelConfig,
+    OpenAIModelConfig,
 ]:
     """
     Create a model configuration based on the provided arguments.

@@ -345,6 +351,9 @@ def create_model_config( # noqa: C901
     if model_args.pop("vllm", False):
         return VLLMModelConfig(**model_args)

+    if model_args.pop("openai", False):
+        return OpenAIModelConfig(**model_args)
+
     model_args["accelerator"] = accelerator
     model_args["use_chat_template"] = use_chat_template
     model_args["compile"] = bool(model_args["compile"]) if "compile" in model_args else False

src/lighteval/models/model_loader.py (22 additions, 1 deletion)

@@ -35,12 +35,20 @@
     DummyModelConfig,
     InferenceEndpointModelConfig,
     InferenceModelConfig,
+    OpenAIModelConfig,
     TGIModelConfig,
     VLLMModelConfig,
 )
+from lighteval.models.openai_model import OpenAIClient
 from lighteval.models.tgi_model import ModelClient
 from lighteval.models.vllm_model import VLLMModel
-from lighteval.utils.imports import NO_TGI_ERROR_MSG, NO_VLLM_ERROR_MSG, is_tgi_available, is_vllm_available
+from lighteval.utils.imports import (
+    NO_TGI_ERROR_MSG,
+    NO_VLLM_ERROR_MSG,
+    is_openai_available,
+    is_tgi_available,
+    is_vllm_available,
+)
 from lighteval.utils.utils import EnvConfig


@@ -53,6 +61,7 @@ def load_model( # noqa: C901
         InferenceEndpointModelConfig,
         DummyModelConfig,
         VLLMModelConfig,
+        OpenAIModelConfig,
     ],
     env_config: EnvConfig,
 ) -> Union[BaseModel, AdapterModel, DeltaModel, ModelClient, DummyModel]:

@@ -87,6 +96,9 @@ def load_model( # noqa: C901
     if isinstance(config, VLLMModelConfig):
         return load_model_with_accelerate_or_default(config=config, env_config=env_config)

+    if isinstance(config, OpenAIModelConfig):
+        return load_openai_model(config=config, env_config=env_config)
+

 def load_model_with_tgi(config: TGIModelConfig):
     if not is_tgi_available():

@@ -99,6 +111,15 @@ def load_model_with_tgi(config: TGIModelConfig):
     return model


+def load_openai_model(config: OpenAIModelConfig, env_config: EnvConfig):
+    if not is_openai_available():
+        raise ImportError()
+
+    model = OpenAIClient(config, env_config)
+
+    return model
+
+
 def load_model_with_inference_endpoints(config: InferenceEndpointModelConfig, env_config: EnvConfig):
     hlog("Spin up model using inference endpoint.")
     model = InferenceEndpointModel(config=config, env_config=env_config)
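
Put together with the config change above, the intended load path looks roughly like this. This is a hedged sketch, assuming the openai extra is installed and OPENAI_API_KEY is exported; EnvConfig is instantiated with its defaults here, which may not match every setup.

from lighteval.models.model_config import OpenAIModelConfig
from lighteval.models.model_loader import load_model
from lighteval.utils.utils import EnvConfig

config = OpenAIModelConfig(model="gpt-4o")  # example model name
model = load_model(config=config, env_config=EnvConfig())  # dispatches to load_openai_model -> OpenAIClient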

src/lighteval/models/model_output.py (1 addition, 1 deletion)

@@ -59,7 +59,7 @@ def get_result_for_eval(self):

 @dataclass
 class GenerativeResponse(ModelResponse):
-    result: str = field(default_factory=str) # generated text continuation
+    result: list[str] = field(default_factory=str) # generated text continuation
     logits: Optional[list[float]] = None # Generated text logits

     def get_result_for_eval(self):
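
Hedged illustration of why the type widens: with num_samples > 1 the OpenAI backend returns one completion per choice, so result now holds a list of strings. The constructor call mirrors the one in openai_model.py below; the strings are example values.

from lighteval.models.model_output import GenerativeResponse

# One generated string per sampled choice, as produced by OpenAIClient.greedy_until.
resp = GenerativeResponse(result=["Paris", "paris"], logits=None, generated_tokens=[], input_tokens=[])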

src/lighteval/models/openai_model.py (new file, 258 additions)

# MIT License

# Copyright (c) 2024 The HuggingFace Team

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import os
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Optional

from tqdm import tqdm

from lighteval.data import GenerativeTaskDataset, LoglikelihoodDataset
from lighteval.logging.hierarchical_logger import hlog_warn
from lighteval.models.abstract_model import LightevalModel
from lighteval.models.endpoint_model import ModelInfo
from lighteval.models.model_output import (
    GenerativeResponse,
    LoglikelihoodResponse,
    LoglikelihoodSingleTokenResponse,
)
from lighteval.tasks.requests import (
    GreedyUntilRequest,
    LoglikelihoodRequest,
    LoglikelihoodRollingRequest,
    LoglikelihoodSingleTokenRequest,
)
from lighteval.utils.imports import is_openai_available


if is_openai_available():
    import logging

    import tiktoken
    from openai import OpenAI

    logging.getLogger("openai").setLevel(logging.ERROR)
    logging.getLogger("httpx").setLevel(logging.ERROR)


class OpenAIClient(LightevalModel):
    _DEFAULT_MAX_LENGTH: int = 4096

    def __init__(self, config, env_config) -> None:
        api_key = os.environ["OPENAI_API_KEY"]
        self.client = OpenAI(api_key=api_key)

        self.model_info = ModelInfo(
            model_name=config.model,
            model_sha="",
            model_dtype=None,
            model_size="",
        )
        self.API_MAX_RETRY = 5
        self.API_RETRY_SLEEP = 3
        self.API_RETRY_MULTIPLIER = 2
        self.CONCURENT_CALLS = 100
        self.model = config.model
        self._tokenizer = tiktoken.encoding_for_model(self.model)
        self.pairwise_tokenization = False

    def __call_api(self, prompt, return_logits, max_new_tokens, num_samples, logit_bias):
        for _ in range(self.API_MAX_RETRY):
            try:
                response = self.client.chat.completions.create(
                    model=self.model,
                    messages=[{"role": "user", "content": prompt}],
                    response_format={"type": "text"},
                    max_tokens=max_new_tokens if max_new_tokens > 0 else None,
                    logprobs=return_logits,
                    logit_bias=logit_bias,
                    n=num_samples,
                )
                return response
            except Exception as e:
                hlog_warn(f"{type(e), e}")
                time.sleep(self.API_RETRY_SLEEP)
                self.API_RETRY_SLEEP = self.API_RETRY_SLEEP**self.API_RETRY_MULTIPLIER
        raise Exception("Failed to get response from the API")

    def __call_api_parallel(
        self,
        prompts,
        return_logits: bool | list[bool],
        max_new_tokens: int | list[int],
        num_samples: int | list[int],
        logit_bias: list[dict[int, float]] | None = None,
    ):
        results = []

        return_logitss = [return_logits for _ in prompts] if not isinstance(return_logits, list) else return_logits
        max_new_tokenss = [max_new_tokens for _ in prompts] if not isinstance(max_new_tokens, list) else max_new_tokens
        num_sampless = [num_samples for _ in prompts] if not isinstance(num_samples, list) else num_samples
        logit_biass = [logit_bias for _ in prompts] if logit_bias is None else logit_bias

        assert (
            len(prompts) == len(return_logitss) == len(max_new_tokenss) == len(num_sampless) == len(logit_biass)
        ), "Length of prompts, return_logitss, max_new_tokenss, num_sampless, logit_biass should be same"

        with ThreadPoolExecutor(self.CONCURENT_CALLS) as executor:
            for entry in tqdm(
                executor.map(self.__call_api, prompts, return_logitss, max_new_tokenss, num_sampless, logit_biass),
                total=len(prompts),
            ):
                results.append(entry)

        if None in results:
            raise ValueError("Some entries are not annotated due to errors in annotate_p, please inspect and retry.")

        return results

    def greedy_until(
        self,
        requests: list[GreedyUntilRequest],
        override_bs: Optional[int] = None,
    ) -> list[GenerativeResponse]:
        """
        Generates responses using a greedy decoding strategy until certain ending conditions are met.

        Args:
            requests (list[Request]): list of requests containing the context and ending conditions.
            disable_tqdm (bool, optional): Whether to disable the progress bar. Defaults to False.
            override_bs (int, optional): Override the batch size for generation. Defaults to None.

        Returns:
            list[GenerativeResponse]: list of generated responses.
        """
        for request in requests:
            request.tokenized_context = self.tok_encode(request.context)

        dataset = GenerativeTaskDataset(requests=requests, num_dataset_splits=self.DATASET_SPLITS)
        results = []

        for _ in tqdm(
            dataset.splits_start_end_iterator(),
            total=dataset.num_dataset_splits,
            desc="Splits",
            position=0,
            disable=False, # self.disable_tqdm,
        ):
            max_new_tokens = dataset[0].generation_size # could be none
            return_logits = dataset[0].use_logits
            num_samples = dataset[0].num_samples
            contexts = [c.context for c in dataset]

            responses = self.__call_api_parallel(contexts, return_logits, max_new_tokens, num_samples)

            for response in responses:
                result: list[str] = [output.message.content for output in response.choices]

                cur_response = GenerativeResponse(
                    result=result,
                    logits=None,
                    generated_tokens=[],
                    input_tokens=[],
                )
                results.append(cur_response)

        return dataset.get_original_order(results)

    @property
    def tokenizer(self):
        return self._tokenizer

    def tok_encode(self, text: str):
        return self.tokenizer.encode(text)

    @property
    def add_special_tokens(self) -> bool:
        return False

    @property
    def max_length(self) -> int:
        """Return the maximum sequence length of the model."""
        return 4096

    def loglikelihood(
        self, requests: list[LoglikelihoodRequest], override_bs: Optional[int] = None
    ) -> list[LoglikelihoodResponse]:
        """Tokenize the context and continuation and compute the log likelihood of those
        tokenized sequences.
        """
        for request in requests:
            if request.context == "":
                request.tokenized_context = [" "]
                request.tokenized_continuation = self.tok_encode(request.choice)
            else:
                # The following line is mandatory for compatibility with the harness
                request.tokenized_context, request.tokenized_continuation = self.tok_encode_pair(
                    request.context, request.choice, pairwise=self.pairwise_tokenization
                )
        return self._loglikelihood_tokens(requests)

    def _loglikelihood_tokens(
        self,
        requests: list[LoglikelihoodRequest],
    ) -> list[LoglikelihoodResponse]:
        dataset = LoglikelihoodDataset(requests=requests, num_dataset_splits=1)
        results = []

        for _ in tqdm(dataset.splits_start_end_iterator()):
            inputs = [dataset[i].context for i in range(len(dataset))]
            logit_biass = []
            max_new_tokens = [len(dataset[i].tokenized_continuation) for i in range(len(dataset))]

            assert all(
                new_tokens == 1 for new_tokens in max_new_tokens
            ), "Only single token continuations are supported when using openai API."

            for i in range(len(dataset)):
                logit_bias = {tok: 100 for tok in dataset[i].tokenized_continuation}
                logit_biass.append(logit_bias)

            outputs = self.__call_api_parallel(
                inputs, return_logits=True, max_new_tokens=max_new_tokens, num_samples=1, logit_bias=logit_biass
            )

            for output, input in zip(outputs, dataset):
                continuation_logprobs = [content.logprob for content in output.choices[0].logprobs.content]
                answer = LoglikelihoodResponse(
                    input_tokens=input.tokenized_context + input.tokenized_continuation,
                    generated_tokens=input.tokenized_continuation,
                    result=(sum(continuation_logprobs), None),
                )
                results.append(answer)

        return dataset.get_original_order(results)

    def loglikelihood_rolling(
        self, requests: list[LoglikelihoodRollingRequest], override_bs: Optional[int] = None
    ) -> list[LoglikelihoodResponse]:
        """This function is used to compute the log likelihood of the context for perplexity metrics."""
        raise NotImplementedError

    def loglikelihood_single_token(
        self, requests: list[LoglikelihoodSingleTokenRequest], override_bs: Optional[int] = None
    ) -> list[LoglikelihoodSingleTokenResponse]:
        """Tokenize the context and continuation and compute the log likelihood of those
        tokenized sequences.
        """
        raise NotImplementedError
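
The loglikelihood path above relies on a logit-bias trick: bias the single continuation token to +100 so the model is forced to emit it, then read its logprob back. Below is a standalone, hedged sketch of that trick against the OpenAI chat API; the model name and prompt are illustrative only, and OPENAI_API_KEY must be set.

import os

import tiktoken
from openai import OpenAI

client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
model = "gpt-3.5-turbo"  # illustrative model choice

enc = tiktoken.encoding_for_model(model)
continuation_tokens = enc.encode(" Paris")
assert len(continuation_tokens) == 1, "only single-token continuations are supported"

response = client.chat.completions.create(
    model=model,
    messages=[{"role": "user", "content": "The capital of France is"}],
    max_tokens=1,
    logprobs=True,
    logit_bias={continuation_tokens[0]: 100},  # force the continuation token
)
# Log likelihood of the forced continuation token:
logprob = response.choices[0].logprobs.content[0].logprob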

src/lighteval/tasks/default_tasks.py (1 addition, 1 deletion)

@@ -8395,7 +8395,7 @@
     few_shots_split=None,
     few_shots_select=None,
     generation_size=1,
-    metric=[Metrics.loglikelihood_acc_single_token, "mcc_single_token"],
+    metric=[Metrics.loglikelihood_acc_single_token, Metrics.mcc_single_token],
     stop_sequence=["\n"],
     output_regex=None,
     frozen=False,

src/lighteval/tasks/prompt_manager.py (1 addition, 1 deletion)

@@ -189,7 +189,7 @@ def _single_turn_context(
             system_prompt=system_prompt,
             use_chat_template=use_chat_template,
         )
-        toks = self.model.tokenizer(output)["input_ids"]
+        toks = self.model.tok_encode(output)

         # If we need to truncate few-shots to fit in the context
         if truncate_few_shots and self.model.max_length is not None and self.model.tokenizer is not None:
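
The switch from calling the tokenizer directly to tok_encode keeps the prompt manager backend-agnostic: a tiktoken Encoding is not callable the way a Hugging Face tokenizer is. A hedged illustration of the difference (encoding name and text are examples only):

import tiktoken

enc = tiktoken.encoding_for_model("gpt-3.5-turbo")
toks = enc.encode("Question: 2 + 2 = ?")  # list[int], the role tokenizer(output)["input_ids"] played for HF models
# enc("...") would raise TypeError: tiktoken encodings are not callable.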
