-
Notifications
You must be signed in to change notification settings - Fork 0
add auth parameter and AzureAuth #15
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
|
||
import os | ||
import inspect | ||
from typing import Any, Union, Mapping, TypeVar, Callable, Awaitable, cast, overload | ||
from typing import TYPE_CHECKING, Any, Union, Mapping, TypeVar, Callable, Awaitable, cast, overload | ||
from typing_extensions import Self, override | ||
|
||
import httpx | ||
|
@@ -42,6 +42,107 @@ | |
API_KEY_SENTINEL = "".join(["<", "missing API key", ">"]) | ||
|
||
|
||
if TYPE_CHECKING: | ||
from azure.core.credentials import TokenCredential | ||
from azure.core.credentials_async import AsyncTokenCredential | ||
|
||
|
||
AzureTokenProvider = Callable[[], str] | ||
AsyncAzureTokenProvider = Callable[[], Awaitable[str]] | ||
|
||
|
||
class AzureAuth: | ||
@overload | ||
def __init__(self, *, token_provider: AzureTokenProvider) -> None: ... | ||
|
||
@overload | ||
def __init__( | ||
self, | ||
*, | ||
credential: TokenCredential, | ||
scopes: list[str] = ["https://cognitiveservices.azure.com/.default"], | ||
) -> None: ... | ||
|
||
def __init__( | ||
self, | ||
*, | ||
token_provider: AzureTokenProvider | None = None, | ||
credential: TokenCredential | None = None, | ||
scopes: list[str] = ["https://cognitiveservices.azure.com/.default"], | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nitpick, parameters of type |
||
) -> None: | ||
if token_provider is not None and credential is not None: | ||
raise ValueError("The `token_provider` and `credential` arguments are mutually exclusive.") | ||
if token_provider is None and credential is None: | ||
raise ValueError("One of `token_provider` or `credential` must be provided to AzureAuth.") | ||
self.token_provider = token_provider | ||
self.credential = credential | ||
self.scopes = scopes | ||
|
||
def get_token(self) -> str: | ||
if self.token_provider is not None: | ||
token = self.token_provider() | ||
return token | ||
|
||
if self.credential is not None: | ||
try: | ||
from azure.identity import get_bearer_token_provider | ||
except ImportError as err: | ||
raise ImportError( | ||
"azure-identity library is not installed. Please install it to use AzureAuth." | ||
) from err | ||
token_provider = get_bearer_token_provider(self.credential, *self.scopes) | ||
|
||
token = token_provider() | ||
return token | ||
|
||
raise ValueError("Unexpected values provided to AzureAuth. Unable to get token.") | ||
|
||
|
||
class AsyncAzureAuth: | ||
@overload | ||
def __init__(self, *, token_provider: AsyncAzureTokenProvider) -> None: ... | ||
|
||
@overload | ||
def __init__( | ||
self, | ||
*, | ||
credential: AsyncTokenCredential, | ||
scopes: list[str] = ["https://cognitiveservices.azure.com/.default"], | ||
) -> None: ... | ||
|
||
def __init__( | ||
self, | ||
*, | ||
token_provider: AsyncAzureTokenProvider | None = None, | ||
credential: AsyncTokenCredential | None = None, | ||
scopes: list[str] = ["https://cognitiveservices.azure.com/.default"], | ||
) -> None: | ||
if token_provider is not None and credential is not None: | ||
raise ValueError("The `token_provider` and `credential` arguments are mutually exclusive.") | ||
if token_provider is None and credential is None: | ||
raise ValueError("One of `token_provider` or `credential` must be provided to AsyncAzureAuth.") | ||
self.token_provider = token_provider | ||
self.credential = credential | ||
self.scopes = scopes | ||
|
||
async def get_token(self) -> str: | ||
if self.token_provider is not None: | ||
token = await self.token_provider() | ||
return token | ||
|
||
if self.credential is not None: | ||
|
||
try: | ||
from azure.identity.aio import get_bearer_token_provider | ||
except ImportError as err: | ||
raise ImportError( | ||
"azure-identity library is not installed. Please install it to use AsyncAzureAuth." | ||
) from err | ||
token_provider = get_bearer_token_provider(self.credential, *self.scopes) | ||
token = await token_provider() | ||
return token | ||
|
||
raise ValueError("Unexpected values provided to AsyncAzureAuth. Unable to get token.") | ||
|
||
|
||
class MutuallyExclusiveAuthError(OpenAIError): | ||
def __init__(self) -> None: | ||
super().__init__( | ||
|
@@ -255,7 +356,7 @@ def __init__( | |
self._azure_endpoint = httpx.URL(azure_endpoint) if azure_endpoint else None | ||
|
||
@override | ||
def copy( | ||
def copy( # type: ignore | ||
self, | ||
*, | ||
api_key: str | None = None, | ||
|
@@ -301,7 +402,7 @@ def copy( | |
}, | ||
) | ||
|
||
with_options = copy | ||
with_options = copy # type: ignore | ||
|
||
def _get_azure_ad_token(self) -> str | None: | ||
if self._azure_ad_token is not None: | ||
|
@@ -536,7 +637,7 @@ def __init__( | |
self._azure_endpoint = httpx.URL(azure_endpoint) if azure_endpoint else None | ||
|
||
@override | ||
def copy( | ||
def copy( # type: ignore | ||
self, | ||
*, | ||
api_key: str | None = None, | ||
|
@@ -582,7 +683,7 @@ def copy( | |
}, | ||
) | ||
|
||
with_options = copy | ||
with_options = copy # type: ignore | ||
|
||
async def _get_azure_ad_token(self) -> str | None: | ||
if self._azure_ad_token is not None: | ||
|
Uh oh!
There was an error while loading. Please reload this page.