@@ -213,9 +213,16 @@ def parse_generations(outputs: Union[Dict, List[Dict]], **kwargs) -> List[str]:
213213 if not isinstance (outputs , list ):
214214 outputs = [outputs ]
215215 for out in outputs :
216- tmp = [None ] * len (out ["choices" ])
217- for choices in out ["choices" ]:
218- tmp [choices ["index" ]] = choices ["message" ]["content" ]
216+ try :
217+ tmp = [None ] * len (out ["choices" ])
218+ for choices in out ["choices" ]:
219+ tmp [choices ["index" ]] = choices ["message" ]["content" ]
220+ except Exception as e :
221+ # account for cases that generation is blocked by content filter,
222+ # which is common for Azure OpenAI Service,
223+ # not sure if need to account for multiple choices
224+ eval_logger .warning (f"Could not parse generations: { e } " )
225+ tmp = ["" ]
219226 res = res + tmp
220227 return res
221228
@@ -346,3 +353,38 @@ def _create_payload(
346353 output .pop ("stop" )
347354 output ["temperature" ] = 1
348355 return output
356+
357+
358+ @register_model ("azure-openai-chat-completions" )
359+ class AzureOpenaiChatCompletionsLM (OpenAIChatCompletion ):
360+ def __init__ (
361+ self ,
362+ model : str = os .getenv ("AZURE_OPENAI_DEPLOYMENT_NAME" ),
363+ base_url : str = os .getenv ("AZURE_OPENAI_ENDPOINT" ),
364+ api_version : str = os .getenv ("AZURE_OPENAI_API_VERSION" , "2025-03-01-preview" ),
365+ truncate : bool = False ,
366+ ** kwargs ,
367+ ) -> None :
368+ super ().__init__ ()
369+ try :
370+ import openai # noqa: E401
371+ except ModuleNotFoundError :
372+ raise Exception (
373+ "attempted to use 'openai' LM type, but package `openai` or `tiktoken` are not installed. \
374+ please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`" ,
375+ )
376+ self .model = model
377+ self .base_url = f"{ base_url } /openai/deployments/{ model } /chat/completions?api-version={ api_version } "
378+ self .truncate = truncate
379+ self .client = openai .AzureOpenAI (
380+ azure_endpoint = base_url , api_version = api_version , api_key = self .api_key
381+ )
382+
383+ @cached_property
384+ def api_key (self ):
385+ key = os .environ .get ("AZURE_OPENAI_API_KEY" , None )
386+ if key is None :
387+ raise ValueError (
388+ "API key not found. Please set the `AZURE_OPENAI_API_KEY` environment variable."
389+ )
390+ return key
0 commit comments