|
9 | 9 | msg = "google-genai is not installed. Please install it with `pip install any-llm-sdk[google]`" |
10 | 10 | raise ImportError(msg) |
11 | 11 |
|
| 12 | +from pydantic import BaseModel |
| 13 | + |
12 | 14 | from openai.types.chat.chat_completion import ChatCompletion |
13 | 15 | from any_llm.provider import Provider, ApiConfig |
14 | 16 | from any_llm.exceptions import MissingApiKeyError |
@@ -97,6 +99,41 @@ def _convert_messages(messages: list[dict[str, Any]]) -> list[types.Content]: |
97 | 99 | return formatted_messages |
98 | 100 |
|
99 | 101 |
|
| 102 | +def _convert_pydantic_to_google_json( |
| 103 | + pydantic_model: type[BaseModel], messages: list[dict[str, Any]] |
| 104 | +) -> list[dict[str, Any]]: |
| 105 | + """ |
| 106 | + Convert Pydantic model to Google-compatible JSON instructions. |
| 107 | +
|
| 108 | + Following a similar pattern to the DeepSeek provider but adapted for Google. |
| 109 | +
|
| 110 | + Returns: |
| 111 | + modified_messages |
| 112 | + """ |
| 113 | + # Get the JSON schema from the Pydantic model |
| 114 | + schema = pydantic_model.model_json_schema() |
| 115 | + |
| 116 | + # Add JSON instructions to the last user message |
| 117 | + modified_messages = messages.copy() |
| 118 | + if modified_messages and modified_messages[-1]["role"] == "user": |
| 119 | + original_content = modified_messages[-1]["content"] |
| 120 | + json_instruction = f""" |
| 121 | +Please respond with a JSON object that matches the following schema: |
| 122 | +
|
| 123 | +{json.dumps(schema, indent=2)} |
| 124 | +
|
| 125 | +Return the JSON object only, no other text, do not wrap it in ```json or ```. |
| 126 | +
|
| 127 | +{original_content} |
| 128 | +""" |
| 129 | + modified_messages[-1]["content"] = json_instruction |
| 130 | + else: |
| 131 | + msg = "Last message is not a user message" |
| 132 | + raise ValueError(msg) |
| 133 | + |
| 134 | + return modified_messages |
| 135 | + |
| 136 | + |
100 | 137 | class GoogleProvider(Provider): |
101 | 138 | """Google Provider using the new response conversion utilities.""" |
102 | 139 |
|
@@ -133,8 +170,15 @@ def completion( |
133 | 170 | **kwargs: Any, |
134 | 171 | ) -> ChatCompletion: |
135 | 172 | """Create a chat completion using Google GenAI.""" |
136 | | - # Remove unsupported parameters |
137 | | - kwargs = remove_unsupported_params(kwargs, ["response_format", "parallel_tool_calls"]) |
| 173 | + # Handle response_format for Pydantic models |
| 174 | + if "response_format" in kwargs: |
| 175 | + response_format = kwargs.pop("response_format") |
| 176 | + if isinstance(response_format, type) and issubclass(response_format, BaseModel): |
| 177 | + # Convert Pydantic model to Google JSON format |
| 178 | + messages = _convert_pydantic_to_google_json(response_format, messages) |
| 179 | + |
| 180 | + # Remove other unsupported parameters |
| 181 | + kwargs = remove_unsupported_params(kwargs, ["parallel_tool_calls"]) |
138 | 182 |
|
139 | 183 | # Convert tools if present |
140 | 184 | tools = None |
|
0 commit comments