|
7 | 7 |
|
8 | 8 | import httpx |
9 | 9 | import requests |
10 | | -from typing import Optional, Any, Dict, List |
| 10 | +from pydantic import BaseModel |
| 11 | +from typing import get_args, get_origin, List, Optional, Dict, Any, Union |
11 | 12 |
|
12 | 13 | import tuneapi.utils as tu |
13 | 14 | import tuneapi.types as tt |
@@ -110,6 +111,106 @@ def _process_header(self): |
110 | 111 | "Content-Type": "application/json", |
111 | 112 | } |
112 | 113 |
|
@staticmethod
def get_structured_schema(model: type[BaseModel]) -> Dict[str, Any]:
    """
    Convert a Pydantic BaseModel class into a JSON-schema style dictionary.

    Mapping rules:
        * ``str`` / ``int`` / ``float`` / ``bool`` -> ``string`` / ``integer`` /
          ``number`` / ``boolean``
        * ``list`` / ``List[X]`` -> ``{"type": "array", "items": <schema of X>}``
          (``items`` is ``{}`` when the list is not parametrized)
        * ``dict`` / ``Dict[K, V]`` -> ``{"type": "object"}`` plus
          ``additionalProperties`` set to the schema of ``V`` when parametrized
        * ``Optional[X]``, ``Union[...]`` and ``X | Y`` ->
          ``{"anyOf": [...]}``; a ``None`` member is emitted last as
          ``{"type": "null"}``
        * nested ``BaseModel`` subclasses -> converted recursively
        * anything else -> ``{"type": "string"}``

    Args:
        model: The Pydantic BaseModel class to convert.

    Returns:
        A dictionary with ``type``, ``properties`` and ``required`` keys, plus
        ``description`` when the model has a docstring.

    Note:
        Self-referential models are not supported: there is no cycle guard, so
        the recursion would not terminate.
    """
    # Imported locally so this method does not depend on a module-level import.
    import types

    none_type = type(None)
    # ``X | Y`` annotations report ``types.UnionType`` as their origin, while
    # ``Optional[X]`` / ``Union[X, Y]`` report ``typing.Union``. Both must be
    # recognised; ``typing.Optional`` itself is never returned by get_origin.
    pep604_union = getattr(types, "UnionType", Union)
    primitive_types = (
        (str, "string"),
        (int, "integer"),
        (float, "number"),
        (bool, "boolean"),
    )

    def _process_field_type(field_type: Any) -> Dict[str, Any]:
        """Return the schema for one type annotation, at any nesting depth."""
        origin = get_origin(field_type)
        args = get_args(field_type)

        if origin is Union or origin is pep604_union:
            members = [
                _process_field_type(arg) for arg in args if arg is not none_type
            ]
            if len(members) != len(args):
                # At least one member was None: the value may be null.
                members.append({"type": "null"})
            return {"anyOf": members}

        # Bare ``list`` / ``dict`` have no origin; treat them like their
        # unparametrized generic counterparts.
        container = field_type if origin is None else origin

        if container is list:
            return {
                "type": "array",
                "items": _process_field_type(args[0]) if args else {},
            }

        if container is dict:
            schema: Dict[str, Any] = {"type": "object"}
            if len(args) == 2:
                schema["additionalProperties"] = _process_field_type(args[1])
            return schema

        # Class checks run only for non-generic annotations, so a generic alias
        # is never handed to issubclass().
        if origin is None:
            for py_type, json_type in primitive_types:
                if field_type is py_type:
                    return {"type": json_type}
            if field_type is none_type:
                return {"type": "null"}
            if isinstance(field_type, type) and issubclass(field_type, BaseModel):
                return _process_model(field_type)

        return {"type": "string"}  # default any other object to string

    def _process_model(model_cls: Any) -> Dict[str, Any]:
        """Return the object schema for one BaseModel subclass."""
        properties: Dict[str, Any] = {}
        required = []

        for field_name, field in model_cls.model_fields.items():
            if field.is_required():
                required.append(field_name)

            field_schema = _process_field_type(field.annotation)
            if field.description:
                # A field-level description wins over a nested model docstring.
                field_schema["description"] = field.description
            properties[field_name] = field_schema

        schema: Dict[str, Any] = {
            "type": "object",
            "properties": properties,
            "required": required,
        }
        if model_cls.__doc__:
            schema["description"] = model_cls.__doc__.strip()
        return schema

    return _process_model(model)
| 213 | + |
113 | 214 | def chat( |
114 | 215 | self, |
115 | 216 | chats: tt.Thread | str, |
@@ -139,11 +240,13 @@ def chat( |
139 | 240 | output = x |
140 | 241 | else: |
141 | 242 | output += x |
142 | | - except Exception as e: |
143 | | - if not x: |
144 | | - raise e |
145 | | - else: |
146 | | - raise ValueError(x) |
| 243 | + except requests.HTTPError as e: |
| 244 | + print(e.response.text) |
| 245 | + raise e |
| 246 | + |
| 247 | + if chats.schema: |
| 248 | + output = chats.schema(**tu.from_json(output)) |
| 249 | + return output |
147 | 250 | return output |
148 | 251 |
|
149 | 252 | def stream_chat( |
@@ -198,11 +301,11 @@ def stream_chat( |
198 | 301 | "stopSequences": [], |
199 | 302 | } |
200 | 303 |
|
201 | | - if chats.gen_schema: |
| 304 | + if chats.schema: |
202 | 305 | generation_config.update( |
203 | 306 | { |
204 | 307 | "response_mime_type": "application/json", |
205 | | - "response_schema": chats.gen_schema, |
| 308 | + "response_schema": self.get_structured_schema(chats.schema), |
206 | 309 | } |
207 | 310 | ) |
208 | 311 | data["generationConfig"] = generation_config |
@@ -376,11 +479,11 @@ async def stream_chat_async( |
376 | 479 | "stopSequences": [], |
377 | 480 | } |
378 | 481 |
|
379 | | - if chats.gen_schema: |
| 482 | + if chats.schema: |
380 | 483 | generation_config.update( |
381 | 484 | { |
382 | 485 | "response_mime_type": "application/json", |
383 | | - "response_schema": chats.gen_schema, |
| 486 | + "response_schema": chats.schema, |
384 | 487 | } |
385 | 488 | ) |
386 | 489 | data["generationConfig"] = generation_config |
|
0 commit comments