Skip to content

Commit 90950a8

Browse files
authored
added azure openai support (#3349)
* added azure openai support * precommit
1 parent e790e1f commit 90950a8

File tree

1 file changed

+45
-3
lines changed

1 file changed

+45
-3
lines changed

lm_eval/models/openai_completions.py

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -213,9 +213,16 @@ def parse_generations(outputs: Union[Dict, List[Dict]], **kwargs) -> List[str]:
213213
if not isinstance(outputs, list):
214214
outputs = [outputs]
215215
for out in outputs:
216-
tmp = [None] * len(out["choices"])
217-
for choices in out["choices"]:
218-
tmp[choices["index"]] = choices["message"]["content"]
216+
try:
217+
tmp = [None] * len(out["choices"])
218+
for choices in out["choices"]:
219+
tmp[choices["index"]] = choices["message"]["content"]
220+
except Exception as e:
221+
# account for cases that generation is blocked by content filter,
222+
# which is common for Azure OpenAI Service,
223+
# not sure if need to account for multiple choices
224+
eval_logger.warning(f"Could not parse generations: {e}")
225+
tmp = [""]
219226
res = res + tmp
220227
return res
221228

@@ -346,3 +353,38 @@ def _create_payload(
346353
output.pop("stop")
347354
output["temperature"] = 1
348355
return output
356+
357+
358+
@register_model("azure-openai-chat-completions")
359+
class AzureOpenaiChatCompletionsLM(OpenAIChatCompletion):
360+
def __init__(
361+
self,
362+
model: str = os.getenv("AZURE_OPENAI_DEPLOYMENT_NAME"),
363+
base_url: str = os.getenv("AZURE_OPENAI_ENDPOINT"),
364+
api_version: str = os.getenv("AZURE_OPENAI_API_VERSION", "2025-03-01-preview"),
365+
truncate: bool = False,
366+
**kwargs,
367+
) -> None:
368+
super().__init__()
369+
try:
370+
import openai # noqa: E401
371+
except ModuleNotFoundError:
372+
raise Exception(
373+
"attempted to use 'openai' LM type, but package `openai` or `tiktoken` are not installed. \
374+
please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`",
375+
)
376+
self.model = model
377+
self.base_url = f"{base_url}/openai/deployments/{model}/chat/completions?api-version={api_version}"
378+
self.truncate = truncate
379+
self.client = openai.AzureOpenAI(
380+
azure_endpoint=base_url, api_version=api_version, api_key=self.api_key
381+
)
382+
383+
@cached_property
384+
def api_key(self):
385+
key = os.environ.get("AZURE_OPENAI_API_KEY", None)
386+
if key is None:
387+
raise ValueError(
388+
"API key not found. Please set the `AZURE_OPENAI_API_KEY` environment variable."
389+
)
390+
return key

0 commit comments

Comments
 (0)