1
1
# ---------------------------------------------------------
2
2
# Copyright (c) Microsoft Corporation. All rights reserved.
3
3
# ---------------------------------------------------------
4
- from azure .ai .evaluation ._model_configurations import AzureOpenAIModelConfiguration , OpenAIModelConfiguration
4
+ from typing import TYPE_CHECKING , Any , Dict , Optional , Union
5
+
6
+ from typing_extensions import TypeIs
5
7
6
- from azure .ai .evaluation ._constants import DEFAULT_AOAI_API_VERSION
8
+ from azure .ai .evaluation ._common ._experimental import experimental
9
+ from azure .ai .evaluation ._constants import DEFAULT_AOAI_API_VERSION , TokenScope
7
10
from azure .ai .evaluation ._exceptions import ErrorBlame , ErrorCategory , ErrorTarget , EvaluationException
11
+ from azure .ai .evaluation ._model_configurations import AzureOpenAIModelConfiguration , OpenAIModelConfiguration
8
12
from azure .ai .evaluation ._user_agent import UserAgentSingleton
9
- from typing import Any , Dict , Union
10
- from azure .ai .evaluation ._common ._experimental import experimental
13
+ from azure .core .credentials import TokenCredential
14
+
15
+ if TYPE_CHECKING :
16
+ from openai .lib .azure import AzureADTokenProvider
11
17
12
18
13
19
@experimental
@@ -30,6 +36,8 @@ class AzureOpenAIGrader:
30
36
to be formatted as a dictionary that matches the specifications of the sub-types of
31
37
the TestingCriterion alias specified in (OpenAI's SDK)[https://github.com/openai/openai-python/blob/ed53107e10e6c86754866b48f8bd862659134ca8/src/openai/types/eval_create_params.py#L151].
32
38
:type grader_config: Dict[str, Any]
39
+ :param credential: The credential to use to authenticate to the model. Only applicable to AzureOpenAI models.
40
+ :type credential: ~azure.core.credentials.TokenCredential
33
41
:param kwargs: Additional keyword arguments to pass to the grader.
34
42
:type kwargs: Any
35
43
@@ -43,31 +51,52 @@ def __init__(
43
51
* ,
44
52
model_config : Union [AzureOpenAIModelConfiguration , OpenAIModelConfiguration ],
45
53
grader_config : Dict [str , Any ],
54
+ credential : Optional [TokenCredential ] = None ,
46
55
** kwargs : Any ,
47
56
):
48
57
self ._model_config = model_config
49
58
self ._grader_config = grader_config
59
+ self ._credential = credential
50
60
51
61
if kwargs .get ("validate" , True ):
52
62
self ._validate_model_config ()
53
63
self ._validate_grader_config ()
54
64
55
65
def _validate_model_config (self ) -> None :
56
66
"""Validate the model configuration that this grader wrapper is using."""
57
- if "api_key" not in self ._model_config or not self ._model_config .get ("api_key" ):
58
- msg = f"{ type (self ).__name__ } : Requires an api_key in the supplied model_config."
59
- raise EvaluationException (
60
- message = msg ,
61
- blame = ErrorBlame .USER_ERROR ,
62
- category = ErrorCategory .INVALID_VALUE ,
63
- target = ErrorTarget .AOAI_GRADER ,
64
- )
67
+ msg = None
68
+ if self ._is_azure_model_config (self ._model_config ):
69
+ if not any (auth for auth in (self ._model_config .get ("api_key" ), self ._credential )):
70
+ msg = (
71
+ f"{ type (self ).__name__ } : Requires an api_key in the supplied model_config, "
72
+ + "or providing a credential to the grader's __init__ method. "
73
+ )
74
+
75
+ else :
76
+ if "api_key" not in self ._model_config or not self ._model_config .get ("api_key" ):
77
+ msg = f"{ type (self ).__name__ } : Requires an api_key in the supplied model_config."
78
+
79
+ if msg is None :
80
+ return
81
+
82
+ raise EvaluationException (
83
+ message = msg ,
84
+ blame = ErrorBlame .USER_ERROR ,
85
+ category = ErrorCategory .INVALID_VALUE ,
86
+ target = ErrorTarget .AOAI_GRADER ,
87
+ )
65
88
66
89
def _validate_grader_config (self ) -> None :
67
90
"""Validate the grader configuration that this grader wrapper is using."""
68
91
69
92
return
70
93
94
+ @staticmethod
95
+ def _is_azure_model_config (
96
+ model_config : Union [AzureOpenAIModelConfiguration , OpenAIModelConfiguration ],
97
+ ) -> TypeIs [AzureOpenAIModelConfiguration ]:
98
+ return "azure_endpoint" in model_config
99
+
71
100
def get_client (self ) -> Any :
72
101
"""Construct an appropriate OpenAI client using this grader's model configuration.
73
102
Returns a slightly different client depending on whether or not this grader's model
@@ -77,23 +106,38 @@ def get_client(self) -> Any:
77
106
:rtype: [~openai.OpenAI, ~openai.AzureOpenAI]
78
107
"""
79
108
default_headers = {"User-Agent" : UserAgentSingleton ().value }
80
- if "azure_endpoint" in self ._model_config :
109
+ model_config : Union [AzureOpenAIModelConfiguration , OpenAIModelConfiguration ] = self ._model_config
110
+ api_key : Optional [str ] = model_config .get ("api_key" )
111
+
112
+ if self ._is_azure_model_config (model_config ):
81
113
from openai import AzureOpenAI
82
114
83
115
# TODO set default values?
84
116
return AzureOpenAI (
85
- azure_endpoint = self . _model_config ["azure_endpoint" ],
86
- api_key = self . _model_config . get ( " api_key" , None ) , # Default-style access to appease linters.
117
+ azure_endpoint = model_config ["azure_endpoint" ],
118
+ api_key = api_key , # Default-style access to appease linters.
87
119
api_version = DEFAULT_AOAI_API_VERSION , # Force a known working version
88
- azure_deployment = self ._model_config .get ("azure_deployment" , "" ),
120
+ azure_deployment = model_config .get ("azure_deployment" , "" ),
121
+ azure_ad_token_provider = self .get_token_provider (self ._credential ) if not api_key else None ,
89
122
default_headers = default_headers ,
90
123
)
91
124
from openai import OpenAI
92
125
93
126
# TODO add default values for base_url and organization?
94
127
return OpenAI (
95
- api_key = self . _model_config [ " api_key" ] ,
96
- base_url = self . _model_config .get ("base_url" , "" ),
97
- organization = self . _model_config .get ("organization" , "" ),
128
+ api_key = api_key ,
129
+ base_url = model_config .get ("base_url" , "" ),
130
+ organization = model_config .get ("organization" , "" ),
98
131
default_headers = default_headers ,
99
132
)
133
+
134
+ @staticmethod
135
+ def get_token_provider (cred : TokenCredential ) -> "AzureADTokenProvider" :
136
+ """Get the token provider the AzureOpenAI client.
137
+
138
+ :param TokenCredential cred: The Azure authentication credential.
139
+ :return: The token provider if a credential is provided, otherwise None.
140
+ :rtype: openai.lib.azure.AzureADTokenProvider
141
+ """
142
+
143
+ return lambda : cred .get_token (TokenScope .COGNITIVE_SERVICES_MANAGEMENT ).token
0 commit comments