11# ---------------------------------------------------------
22# Copyright (c) Microsoft Corporation. All rights reserved.
33# ---------------------------------------------------------
4- from azure .ai .evaluation ._model_configurations import AzureOpenAIModelConfiguration , OpenAIModelConfiguration
4+ from typing import TYPE_CHECKING , Any , Dict , Optional , Union
5+
6+ from typing_extensions import TypeIs
57
6- from azure .ai .evaluation ._constants import DEFAULT_AOAI_API_VERSION
8+ from azure .ai .evaluation ._common ._experimental import experimental
9+ from azure .ai .evaluation ._constants import DEFAULT_AOAI_API_VERSION , TokenScope
710from azure .ai .evaluation ._exceptions import ErrorBlame , ErrorCategory , ErrorTarget , EvaluationException
11+ from azure .ai .evaluation ._model_configurations import AzureOpenAIModelConfiguration , OpenAIModelConfiguration
812from azure .ai .evaluation ._user_agent import UserAgentSingleton
9- from typing import Any , Dict , Union
10- from azure .ai .evaluation ._common ._experimental import experimental
13+ from azure .core .credentials import TokenCredential
14+
15+ if TYPE_CHECKING :
16+ from openai .lib .azure import AzureADTokenProvider
1117
1218
1319@experimental
@@ -30,6 +36,8 @@ class AzureOpenAIGrader:
3036 to be formatted as a dictionary that matches the specifications of the sub-types of
3137 the TestingCriterion alias specified in (OpenAI's SDK)[https://github.com/openai/openai-python/blob/ed53107e10e6c86754866b48f8bd862659134ca8/src/openai/types/eval_create_params.py#L151].
3238 :type grader_config: Dict[str, Any]
39+ :param credential: The credential to use to authenticate to the model. Only applicable to AzureOpenAI models.
40+ :type credential: ~azure.core.credentials.TokenCredential
3341 :param kwargs: Additional keyword arguments to pass to the grader.
3442 :type kwargs: Any
3543
@@ -43,31 +51,52 @@ def __init__(
4351 * ,
4452 model_config : Union [AzureOpenAIModelConfiguration , OpenAIModelConfiguration ],
4553 grader_config : Dict [str , Any ],
54+ credential : Optional [TokenCredential ] = None ,
4655 ** kwargs : Any ,
4756 ):
4857 self ._model_config = model_config
4958 self ._grader_config = grader_config
59+ self ._credential = credential
5060
5161 if kwargs .get ("validate" , True ):
5262 self ._validate_model_config ()
5363 self ._validate_grader_config ()
5464
5565 def _validate_model_config (self ) -> None :
5666 """Validate the model configuration that this grader wrapper is using."""
57- if "api_key" not in self ._model_config or not self ._model_config .get ("api_key" ):
58- msg = f"{ type (self ).__name__ } : Requires an api_key in the supplied model_config."
59- raise EvaluationException (
60- message = msg ,
61- blame = ErrorBlame .USER_ERROR ,
62- category = ErrorCategory .INVALID_VALUE ,
63- target = ErrorTarget .AOAI_GRADER ,
64- )
67+ msg = None
68+ if self ._is_azure_model_config (self ._model_config ):
69+ if not any (auth for auth in (self ._model_config .get ("api_key" ), self ._credential )):
70+ msg = (
71+ f"{ type (self ).__name__ } : Requires an api_key in the supplied model_config, "
72+ + "or providing a credential to the grader's __init__ method. "
73+ )
74+
75+ else :
76+ if "api_key" not in self ._model_config or not self ._model_config .get ("api_key" ):
77+ msg = f"{ type (self ).__name__ } : Requires an api_key in the supplied model_config."
78+
79+ if msg is None :
80+ return
81+
82+ raise EvaluationException (
83+ message = msg ,
84+ blame = ErrorBlame .USER_ERROR ,
85+ category = ErrorCategory .INVALID_VALUE ,
86+ target = ErrorTarget .AOAI_GRADER ,
87+ )
6588
6689 def _validate_grader_config (self ) -> None :
6790 """Validate the grader configuration that this grader wrapper is using."""
6891
6992 return
7093
94+ @staticmethod
95+ def _is_azure_model_config (
96+ model_config : Union [AzureOpenAIModelConfiguration , OpenAIModelConfiguration ],
97+ ) -> TypeIs [AzureOpenAIModelConfiguration ]:
98+ return "azure_endpoint" in model_config
99+
71100 def get_client (self ) -> Any :
72101 """Construct an appropriate OpenAI client using this grader's model configuration.
73102 Returns a slightly different client depending on whether or not this grader's model
@@ -77,23 +106,38 @@ def get_client(self) -> Any:
77106 :rtype: [~openai.OpenAI, ~openai.AzureOpenAI]
78107 """
79108 default_headers = {"User-Agent" : UserAgentSingleton ().value }
80- if "azure_endpoint" in self ._model_config :
109+ model_config : Union [AzureOpenAIModelConfiguration , OpenAIModelConfiguration ] = self ._model_config
110+ api_key : Optional [str ] = model_config .get ("api_key" )
111+
112+ if self ._is_azure_model_config (model_config ):
81113 from openai import AzureOpenAI
82114
83115 # TODO set default values?
84116 return AzureOpenAI (
85- azure_endpoint = self . _model_config ["azure_endpoint" ],
86- api_key = self . _model_config . get ( " api_key" , None ) , # Default-style access to appease linters.
117+ azure_endpoint = model_config ["azure_endpoint" ],
118+ api_key = api_key , # Default-style access to appease linters.
87119 api_version = DEFAULT_AOAI_API_VERSION , # Force a known working version
88- azure_deployment = self ._model_config .get ("azure_deployment" , "" ),
120+ azure_deployment = model_config .get ("azure_deployment" , "" ),
121+ azure_ad_token_provider = self .get_token_provider (self ._credential ) if not api_key else None ,
89122 default_headers = default_headers ,
90123 )
91124 from openai import OpenAI
92125
93126 # TODO add default values for base_url and organization?
94127 return OpenAI (
95- api_key = self . _model_config [ " api_key" ] ,
96- base_url = self . _model_config .get ("base_url" , "" ),
97- organization = self . _model_config .get ("organization" , "" ),
128+ api_key = api_key ,
129+ base_url = model_config .get ("base_url" , "" ),
130+ organization = model_config .get ("organization" , "" ),
98131 default_headers = default_headers ,
99132 )
133+
134+ @staticmethod
135+ def get_token_provider (cred : TokenCredential ) -> "AzureADTokenProvider" :
136+ """Get the token provider the AzureOpenAI client.
137+
138+ :param TokenCredential cred: The Azure authentication credential.
139+ :return: The token provider if a credential is provided, otherwise None.
140+ :rtype: openai.lib.azure.AzureADTokenProvider
141+ """
142+
143+ return lambda : cred .get_token (TokenScope .COGNITIVE_SERVICES_MANAGEMENT ).token
0 commit comments