|
6 | 6 |
|
7 | 7 | import magic |
8 | 8 | from google import genai |
| 9 | +from google.auth.credentials import Credentials |
9 | 10 | from google.genai import types |
10 | 11 | from google.genai.types import ( |
11 | 12 | CountTokensConfig, |
|
26 | 27 | from openai.types.chat.chat_completion import ChatCompletion, Choice |
27 | 28 | from pydantic import BaseModel |
28 | 29 | from pydantic_ai.messages import ModelMessage, ModelResponse |
29 | | -from pydantic_ai.models import Model, ModelRequestParameters, StreamedResponse |
| 30 | +from pydantic_ai.models import Model as PydanticAiModel |
| 31 | +from pydantic_ai.models import ModelRequestParameters, StreamedResponse |
30 | 32 | from pydantic_ai.models.gemini import GeminiModel |
31 | 33 | from pydantic_ai.settings import ModelSettings |
32 | 34 | from pydantic_ai.usage import Usage |
@@ -54,27 +56,29 @@ class GoogleLlmClient(LlmClient): |
54 | 56 | ] |
55 | 57 | __MODEL_PREFIX = "models/" |
56 | 58 |
|
57 | | - def __init__(self, api_key: str, location: Optional[str] = None): |
| 59 | + def __init__(self, api_key: str, is_gcp: bool = False): |
58 | 60 | self.__api_key = api_key |
59 | | - self.__location = location |
60 | | - self.client = genai.Client(api_key=api_key, location=location) |
| 61 | + self.__is_gcp = is_gcp |
| 62 | + if not is_gcp: |
| 63 | + self.client = genai.Client(api_key=api_key) |
| 64 | + else: |
| 65 | + self.client = genai.Client(api_key=api_key, vertexai=True, credentials=Credentials()) |
61 | 66 |
|
62 | 67 | @lru_cache(maxsize=1) |
63 | 68 | def __get_models_info(self) -> list[Model]: |
64 | 69 | return list(self.client.models.list()) |
65 | 70 |
|
66 | | - def __get_pydantic_model(self, model_settings: ModelSettings | None) -> Model: |
| 71 | + def __get_pydantic_model(self, model_settings: ModelSettings | None) -> PydanticAiModel: |
67 | 72 | if model_settings is None: |
68 | 73 | raise ValueError("Model settings cannot be None") |
69 | 74 | model_name = model_settings.get("model") |
70 | 75 | if model_name is None: |
71 | 76 | raise ValueError("Model must be set cannot be None") |
72 | 77 |
|
73 | | - if self.__location is None: |
| 78 | + if not self.__is_gcp: |
74 | 79 | return GeminiModel(model_name, api_key=self.__api_key) |
75 | | - |
76 | | - url_template = f"https://{self.__location}-generativelanguage.googleapis.com/v1beta/models/{{model}}:" |
77 | | - return GeminiModel(model_name, api_key=self.__api_key, url_template=url_template) |
| 80 | + else: |
| 81 | + return GeminiModel(model_name, provider="google-vertex") |
78 | 82 |
|
79 | 83 | async def request( |
80 | 84 | self, |
|
0 commit comments