|
5 | 5 | from functools import cached_property, lru_cache |
6 | 6 | from pathlib import Path |
7 | 7 |
|
8 | | -from anthropic import Anthropic |
| 8 | +from anthropic import Anthropic, AnthropicBedrock |
9 | 9 | from anthropic.types import Message, MessageParam, TextBlockParam |
10 | 10 | from openai.types.chat import ( |
11 | 11 | ChatCompletion, |
|
24 | 24 | from pydantic_ai.messages import ModelMessage, ModelResponse |
25 | 25 | from pydantic_ai.models import Model, ModelRequestParameters, StreamedResponse |
26 | 26 | from pydantic_ai.models.anthropic import AnthropicModel |
| 27 | +from pydantic_ai.models.bedrock import BedrockConverseModel |
27 | 28 | from pydantic_ai.settings import ModelSettings |
28 | 29 | from pydantic_ai.usage import Usage |
29 | 30 | from typing_extensions import AsyncIterator, Dict, Iterable, List, Optional, Union |
30 | | - |
| 31 | +import boto3 |
31 | 32 | from patchwork.common.client.llm.protocol import NOT_GIVEN, LlmClient, NotGiven |
32 | 33 |
|
33 | 34 |
|
@@ -78,21 +79,29 @@ class AnthropicLlmClient(LlmClient): |
78 | 79 | __definitely_allowed_models = {"claude-2.0", "claude-2.1", "claude-instant-1.2"} |
79 | 80 | __100k_models = {"claude-2.0", "claude-instant-1.2"} |
80 | 81 |
|
def __init__(self, api_key: Optional[str] = None, is_aws: bool = False):
    """Initialize the client for either the direct Anthropic API or AWS Bedrock.

    Args:
        api_key: Anthropic API key. May be omitted only when ``is_aws`` is True,
            in which case credentials come from the AWS environment instead.
        is_aws: When True, back the client with AWS Bedrock rather than the
            direct Anthropic API.

    Raises:
        ValueError: If ``api_key`` is None while ``is_aws`` is False.
    """
    self.__api_key = api_key
    self.__is_aws = is_aws
    # A direct-API client cannot work without a key; Bedrock does not need one.
    if api_key is None and not is_aws:
        raise ValueError("api_key is required if is_aws is False")
|
@cached_property
def __client(self):
    """Lazily construct and cache the underlying SDK client.

    Returns an ``AnthropicBedrock`` client when configured for AWS (credentials
    and region are resolved from the AWS environment), otherwise a plain
    ``Anthropic`` client authenticated with the stored API key.
    """
    if self.__is_aws:
        return AnthropicBedrock()
    return Anthropic(api_key=self.__api_key)
87 | 94 |
|
def __get_pydantic_model(self, model_settings: ModelSettings | None) -> Model:
    """Build the pydantic-ai model for the configured backend.

    Args:
        model_settings: Settings mapping that must contain a "model" entry
            naming the model to use.

    Returns:
        A ``BedrockConverseModel`` when running against AWS Bedrock, otherwise
        an ``AnthropicModel`` authenticated with the stored API key.

    Raises:
        ValueError: If ``model_settings`` is None or its "model" entry is missing.
    """
    if model_settings is None:
        raise ValueError("Model settings cannot be None")
    model_name = model_settings.get("model")
    if model_name is None:
        # Fix: original message read "Model must be set cannot be None" (garbled).
        raise ValueError("Model must be set and cannot be None")
    if self.__is_aws:
        # Bedrock resolves credentials from the AWS environment, not an API key.
        return BedrockConverseModel(model_name)
    return AnthropicModel(model_name, api_key=self.__api_key)
96 | 105 |
|
97 | 106 | async def request( |
98 | 107 | self, |
@@ -247,10 +256,23 @@ def __adapt_chat_completion_request( |
247 | 256 |
|
@lru_cache(maxsize=None)
def get_models(self) -> set[str]:
    """Return the set of model identifiers available on the active backend.

    Direct API: ids from the Anthropic ``models.list`` endpoint. Bedrock:
    ``modelId`` values of Anthropic foundation models reported by AWS.

    NOTE(review): ``lru_cache`` on an instance method keeps ``self`` alive for
    the cache's lifetime (ruff B019) — likely fine for a long-lived client,
    but worth confirming.
    """
    if self.__is_aws:
        bedrock = boto3.client(service_name="bedrock")
        response = bedrock.list_foundation_models(byProvider="anthropic")
        return {summary["modelId"] for summary in response["modelSummaries"]}
    return {model_info.id for model_info in self.__client.models.list()}
251 | 270 |
|
def is_model_supported(self, model: str) -> bool:
    """Report whether ``model`` can be served by this client.

    Direct API: exact membership in the listed model ids. Bedrock: accepted
    when the requested name ends with a known model id, which tolerates
    region/profile prefixes (e.g. "us.").
    """
    known = self.get_models()
    if self.__is_aws:
        return any(model.endswith(model_id) for model_id in known)
    return model in known
254 | 276 |
|
255 | 277 | def is_prompt_supported( |
256 | 278 | self, |
|
0 commit comments