-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Allow proper type on AnthropicProvider
when using Bedrock
#2490
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
b856c3a
20a8451
aa939f1
c10c059
546e824
9b84bcd
0fb8f3a
cc388ed
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -1,26 +1,29 @@ | ||||||
from __future__ import annotations as _annotations | ||||||
|
||||||
import os | ||||||
from typing import overload | ||||||
from typing import Union, overload | ||||||
|
||||||
import httpx | ||||||
|
||||||
from pydantic_ai.exceptions import UserError | ||||||
from pydantic_ai.models import cached_async_http_client | ||||||
from pydantic_ai.profiles import ModelProfile | ||||||
from pydantic_ai.profiles.anthropic import anthropic_model_profile | ||||||
from pydantic_ai.providers import Provider | ||||||
|
||||||
try: | ||||||
from anthropic import AsyncAnthropic | ||||||
from anthropic import AsyncAnthropic, AsyncAnthropicBedrock | ||||||
except ImportError as _import_error: # pragma: no cover | ||||||
raise ImportError( | ||||||
'Please install the `anthropic` package to use the Anthropic provider, ' | ||||||
'you can use the `anthropic` optional group — `pip install "pydantic-ai-slim[anthropic]"`' | ||||||
) from _import_error | ||||||
|
||||||
from typing_extensions import TypeAlias | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please move this import up (before the try..except). |
||||||
|
||||||
class AnthropicProvider(Provider[AsyncAnthropic]): | ||||||
ASYNC_ANTHROPIC_CLIENT: TypeAlias = Union[AsyncAnthropic, AsyncAnthropicBedrock] | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is it really union what we want?
Suggested change
Does this work? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Don't use full uppercase for type aliases, that's for constants. 🙏 |
||||||
|
||||||
|
||||||
class AnthropicProvider(Provider[ASYNC_ANTHROPIC_CLIENT]): | ||||||
"""Provider for Anthropic API.""" | ||||||
|
||||||
@property | ||||||
|
@@ -32,30 +35,52 @@ def base_url(self) -> str: | |||||
return str(self._client.base_url) | ||||||
|
||||||
@property | ||||||
def client(self) -> AsyncAnthropic: | ||||||
def client(self) -> ASYNC_ANTHROPIC_CLIENT: | ||||||
return self._client | ||||||
|
||||||
def model_profile(self, model_name: str) -> ModelProfile | None: | ||||||
return anthropic_model_profile(model_name) | ||||||
|
||||||
@overload | ||||||
def __init__(self, *, anthropic_client: AsyncAnthropic | None = None) -> None: ... | ||||||
def __init__(self, *, anthropic_client: ASYNC_ANTHROPIC_CLIENT | None = None) -> None: ... | ||||||
|
||||||
@overload | ||||||
def __init__(self, *, api_key: str | None = None, http_client: httpx.AsyncClient | None = None) -> None: ... | ||||||
|
||||||
@overload | ||||||
def __init__( | ||||||
self, | ||||||
*, | ||||||
aws_secret_key: str | None = None, | ||||||
aws_access_key: str | None = None, | ||||||
aws_region: str | None = None, | ||||||
aws_profile: str | None = None, | ||||||
aws_session_token: str | None = None, | ||||||
Comment on lines
+54
to
+58
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would prefer to not complicate things here. The user can just pass the client when using Bedrock. |
||||||
http_client: httpx.AsyncClient | None = None, | ||||||
) -> None: ... | ||||||
|
||||||
def __init__( | ||||||
self, | ||||||
*, | ||||||
api_key: str | None = None, | ||||||
anthropic_client: AsyncAnthropic | None = None, | ||||||
aws_secret_key: str | None = None, | ||||||
aws_access_key: str | None = None, | ||||||
aws_region: str | None = None, | ||||||
aws_profile: str | None = None, | ||||||
aws_session_token: str | None = None, | ||||||
anthropic_client: ASYNC_ANTHROPIC_CLIENT | None = None, | ||||||
http_client: httpx.AsyncClient | None = None, | ||||||
) -> None: | ||||||
"""Create a new Anthropic provider. | ||||||
|
||||||
Args: | ||||||
api_key: The API key to use for authentication, if not provided, the `ANTHROPIC_API_KEY` environment variable | ||||||
will be used if available. | ||||||
aws_secret_key: AWS secret access key for Bedrock authentication. | ||||||
aws_access_key: AWS access key ID for Bedrock authentication. | ||||||
aws_region: AWS region for Bedrock service. | ||||||
aws_profile: AWS profile name for Bedrock authentication. | ||||||
aws_session_token: AWS session token for temporary credentials. | ||||||
anthropic_client: An existing [`AsyncAnthropic`](https://github.com/anthropics/anthropic-sdk-python) | ||||||
client to use. If provided, the `api_key` and `http_client` arguments will be ignored. | ||||||
http_client: An existing `httpx.AsyncClient` to use for making HTTP requests. | ||||||
|
@@ -66,14 +91,30 @@ def __init__( | |||||
self._client = anthropic_client | ||||||
else: | ||||||
api_key = api_key or os.environ.get('ANTHROPIC_API_KEY') | ||||||
if not api_key: | ||||||
raise UserError( | ||||||
'Set the `ANTHROPIC_API_KEY` environment variable or pass it via `AnthropicProvider(api_key=...)`' | ||||||
'to use the Anthropic provider.' | ||||||
) | ||||||
|
||||||
if http_client is not None: | ||||||
self._client = AsyncAnthropic(api_key=api_key, http_client=http_client) | ||||||
if api_key is None: | ||||||
if http_client is not None: | ||||||
self._client = AsyncAnthropicBedrock( | ||||||
aws_access_key=aws_access_key, | ||||||
aws_secret_key=aws_secret_key, | ||||||
aws_session_token=aws_session_token, | ||||||
aws_profile=aws_profile, | ||||||
aws_region=aws_region, | ||||||
http_client=http_client, | ||||||
) | ||||||
else: | ||||||
http_client = cached_async_http_client(provider='anthropic') | ||||||
self._client = AsyncAnthropicBedrock( | ||||||
aws_access_key=aws_access_key, | ||||||
aws_secret_key=aws_secret_key, | ||||||
aws_session_token=aws_session_token, | ||||||
aws_profile=aws_profile, | ||||||
aws_region=aws_region, | ||||||
http_client=http_client, | ||||||
) | ||||||
|
||||||
else: | ||||||
http_client = cached_async_http_client(provider='anthropic') | ||||||
self._client = AsyncAnthropic(api_key=api_key, http_client=http_client) | ||||||
if http_client is not None: | ||||||
self._client = AsyncAnthropic(api_key=api_key, http_client=http_client) | ||||||
else: | ||||||
http_client = cached_async_http_client(provider='anthropic') | ||||||
self._client = AsyncAnthropic(api_key=api_key, http_client=http_client) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've applied this in another PR.