diff --git a/pydantic_ai_slim/pydantic_ai/_json_schema.py b/pydantic_ai_slim/pydantic_ai/_json_schema.py index cbaa180208..bd174f81e2 100644 --- a/pydantic_ai_slim/pydantic_ai/_json_schema.py +++ b/pydantic_ai_slim/pydantic_ai/_json_schema.py @@ -25,7 +25,6 @@ def __init__( *, strict: bool | None = None, prefer_inlined_defs: bool = False, - simplify_nullable_unions: bool = False, ): self.schema = schema @@ -33,7 +32,6 @@ def __init__( self.is_strict_compatible = True # Can be set to False by subclasses to set `strict` on `ToolDefinition` when set not set by user explicitly self.prefer_inlined_defs = prefer_inlined_defs - self.simplify_nullable_unions = simplify_nullable_unions self.defs: dict[str, JsonSchema] = self.schema.get('$defs', {}) self.refs_stack: list[str] = [] @@ -146,39 +144,11 @@ def _handle_union(self, schema: JsonSchema, union_kind: Literal['anyOf', 'oneOf' handled = [self._handle(member) for member in members] - # convert nullable unions to nullable types - if self.simplify_nullable_unions: - handled = self._simplify_nullable_union(handled) - - if len(handled) == 1: - # In this case, no need to retain the union - return handled[0] | schema - # If we have keys besides the union kind (such as title or discriminator), keep them without modifications schema = schema.copy() schema[union_kind] = handled return schema - @staticmethod - def _simplify_nullable_union(cases: list[JsonSchema]) -> list[JsonSchema]: - # TODO: Should we move this to relevant subclasses? Or is it worth keeping here to make reuse easier? - if len(cases) == 2 and {'type': 'null'} in cases: - # Find the non-null schema - non_null_schema = next( - (item for item in cases if item != {'type': 'null'}), - None, - ) - if non_null_schema: - # Create a new schema based on the non-null part, mark as nullable - new_schema = deepcopy(non_null_schema) - new_schema['nullable'] = True - return [new_schema] - else: # pragma: no cover - # they are both null, so just return one of them - return [cases[0]] - - return cases - class InlineDefsJsonSchemaTransformer(JsonSchemaTransformer): """Transforms the JSON Schema to inline $defs.""" diff --git a/pydantic_ai_slim/pydantic_ai/models/google.py b/pydantic_ai_slim/pydantic_ai/models/google.py index 071f65fa66..592cc77ed8 100644 --- a/pydantic_ai_slim/pydantic_ai/models/google.py +++ b/pydantic_ai_slim/pydantic_ai/models/google.py @@ -267,7 +267,7 @@ async def count_tokens( messages, model_settings, model_request_parameters ) - # Annoyingly, the type of `GenerateContentConfigDict.get` is "partially `Unknown`" because `response_schema` includes `typing._UnionGenericAlias`, + # Annoyingly, the type of `GenerateContentConfigDict.get` is "partially `Unknown`" because `response_json_schema` includes `typing._UnionGenericAlias`, # so without this we'd need `pyright: ignore[reportUnknownMemberType]` on every line and wouldn't get type checking anyway. generation_config = cast(dict[str, Any], generation_config) @@ -291,7 +291,7 @@ async def count_tokens( thinking_config=generation_config.get('thinking_config'), media_resolution=generation_config.get('media_resolution'), response_mime_type=generation_config.get('response_mime_type'), - response_schema=generation_config.get('response_schema'), + response_json_schema=generation_config.get('response_json_schema'), ), ) @@ -455,7 +455,7 @@ async def _build_content_and_config( tools=cast(ToolListUnionDict, tools), tool_config=tool_config, response_mime_type=response_mime_type, - response_schema=response_schema, + response_json_schema=response_schema, response_modalities=modalities, ) return contents, config diff --git a/pydantic_ai_slim/pydantic_ai/profiles/google.py b/pydantic_ai_slim/pydantic_ai/profiles/google.py index e8a88ac223..2e691bb42c 100644 --- a/pydantic_ai_slim/pydantic_ai/profiles/google.py +++ b/pydantic_ai_slim/pydantic_ai/profiles/google.py @@ -1,9 +1,5 @@ from __future__ import annotations as _annotations -import warnings - -from pydantic_ai.exceptions import UserError - from .._json_schema import JsonSchema, JsonSchemaTransformer from . import ModelProfile @@ -23,35 +19,11 @@ def google_model_profile(model_name: str) -> ModelProfile | None: class GoogleJsonSchemaTransformer(JsonSchemaTransformer): """Transforms the JSON Schema from Pydantic to be suitable for Gemini. - Gemini which [supports](https://ai.google.dev/gemini-api/docs/function-calling#function_declarations) - a subset of OpenAPI v3.0.3. - - Specifically: - * gemini doesn't allow the `title` keyword to be set - * gemini doesn't allow `$defs` — we need to inline the definitions where possible + Gemini supports [a subset of OpenAPI v3.0.3](https://ai.google.dev/gemini-api/docs/function-calling#function_declarations). """ - def __init__(self, schema: JsonSchema, *, strict: bool | None = None): - super().__init__(schema, strict=strict, prefer_inlined_defs=True, simplify_nullable_unions=True) - def transform(self, schema: JsonSchema) -> JsonSchema: - # Note: we need to remove `additionalProperties: False` since it is currently mishandled by Gemini - additional_properties = schema.pop( - 'additionalProperties', None - ) # don't pop yet so it's included in the warning - if additional_properties: - original_schema = {**schema, 'additionalProperties': additional_properties} - warnings.warn( - '`additionalProperties` is not supported by Gemini; it will be removed from the tool JSON schema.' - f' Full schema: {self.schema}\n\n' - f'Source of additionalProperties within the full schema: {original_schema}\n\n' - 'If this came from a field with a type like `dict[str, MyType]`, that field will always be empty.\n\n' - "If Google's APIs are updated to support this properly, please create an issue on the Pydantic AI GitHub" - ' and we will fix this behavior.', - UserWarning, - ) - - schema.pop('title', None) + # Remove properties not supported by Gemini schema.pop('$schema', None) if (const := schema.pop('const', None)) is not None: # Gemini doesn't support const, but it does support enum with a single value @@ -59,24 +31,7 @@ def transform(self, schema: JsonSchema) -> JsonSchema: schema.pop('discriminator', None) schema.pop('examples', None) - # TODO: Should we use the trick from pydantic_ai.models.openai._OpenAIJsonSchema - # where we add notes about these properties to the field description? - schema.pop('exclusiveMaximum', None) - schema.pop('exclusiveMinimum', None) - - # Gemini only supports string enums, so we need to convert any enum values to strings. - # Pydantic will take care of transforming the transformed string values to the correct type. - if enum := schema.get('enum'): - schema['type'] = 'string' - schema['enum'] = [str(val) for val in enum] - type_ = schema.get('type') - if 'oneOf' in schema and 'type' not in schema: # pragma: no cover - # This gets hit when we have a discriminated union - # Gemini returns an API error in this case even though it says in its error message it shouldn't... - # Changing the oneOf to an anyOf prevents the API error and I think is functionally equivalent - schema['anyOf'] = schema.pop('oneOf') - if type_ == 'string' and (fmt := schema.pop('format', None)): description = schema.get('description') if description: @@ -84,23 +39,8 @@ def transform(self, schema: JsonSchema) -> JsonSchema: else: schema['description'] = f'Format: {fmt}' - if '$ref' in schema: - raise UserError(f'Recursive `$ref`s in JSON Schema are not supported by Gemini: {schema["$ref"]}') - - if 'prefixItems' in schema: - # prefixItems is not currently supported in Gemini, so we convert it to items for best compatibility - prefix_items = schema.pop('prefixItems') - items = schema.get('items') - unique_items = [items] if items is not None else [] - for item in prefix_items: - if item not in unique_items: - unique_items.append(item) - if len(unique_items) > 1: # pragma: no cover - schema['items'] = {'anyOf': unique_items} - elif len(unique_items) == 1: # pragma: no branch - schema['items'] = unique_items[0] - schema.setdefault('minItems', len(prefix_items)) - if items is None: # pragma: no branch - schema.setdefault('maxItems', len(prefix_items)) + # Note: exclusiveMinimum/exclusiveMaximum are NOT yet supported + schema.pop('exclusiveMinimum', None) + schema.pop('exclusiveMaximum', None) return schema diff --git a/tests/json_body_serializer.py b/tests/json_body_serializer.py index bfb2317c01..a0cadd3259 100644 --- a/tests/json_body_serializer.py +++ b/tests/json_body_serializer.py @@ -76,7 +76,7 @@ def serialize(cassette_dict: Any): # pragma: lax no cover del data['body'] if content_type == ['application/x-www-form-urlencoded']: query_params = urllib.parse.parse_qs(data['body']) - for key in ['client_secret', 'refresh_token']: # pragma: no cover + for key in ['client_id', 'client_secret', 'refresh_token']: # pragma: no cover if key in query_params: query_params[key] = ['scrubbed'] data['body'] = urllib.parse.urlencode(query_params) diff --git a/tests/models/cassettes/test_google/test_google_dict_with_additional_properties_native_output.yaml b/tests/models/cassettes/test_google/test_google_dict_with_additional_properties_native_output.yaml new file mode 100644 index 0000000000..314d164413 --- /dev/null +++ b/tests/models/cassettes/test_google/test_google_dict_with_additional_properties_native_output.yaml @@ -0,0 +1,78 @@ +interactions: +- request: + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '519' + content-type: + - application/json + host: + - generativelanguage.googleapis.com + method: POST + parsed_body: + contents: + - parts: + - text: Create a config named "api-config" with metadata author="Alice" and version="1.0" + role: user + generationConfig: + responseJsonSchema: + description: A response with configuration metadata. + properties: + metadata: + additionalProperties: + type: string + type: object + name: + type: string + required: + - name + - metadata + title: ConfigResponse + type: object + responseMimeType: application/json + responseModalities: + - TEXT + uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash:generateContent + response: + headers: + alt-svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + content-length: + - '631' + content-type: + - application/json; charset=UTF-8 + server-timing: + - gfet4t7; dur=1379 + transfer-encoding: + - chunked + vary: + - Origin + - X-Origin + - Referer + parsed_body: + candidates: + - content: + parts: + - text: '{"name": "api-config", "metadata": {"author": "Alice", "version": "1.0"}}' + role: model + finishReason: STOP + index: 0 + modelVersion: gemini-2.5-flash + responseId: CZMUacOtKv2SxN8Pi7TrsAs + usageMetadata: + candidatesTokenCount: 25 + promptTokenCount: 23 + promptTokensDetails: + - modality: TEXT + tokenCount: 23 + thoughtsTokenCount: 158 + totalTokenCount: 206 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_google/test_google_discriminated_union_native_output.yaml b/tests/models/cassettes/test_google/test_google_discriminated_union_native_output.yaml new file mode 100644 index 0000000000..9febb08370 --- /dev/null +++ b/tests/models/cassettes/test_google/test_google_discriminated_union_native_output.yaml @@ -0,0 +1,102 @@ +interactions: +- request: + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '807' + content-type: + - application/json + host: + - generativelanguage.googleapis.com + method: POST + parsed_body: + contents: + - parts: + - text: Tell me about a cat with a meow volume of 5 + role: user + generationConfig: + responseJsonSchema: + $defs: + Cat: + properties: + meow_volume: + type: integer + pet_type: + default: cat + enum: + - cat + type: string + required: + - meow_volume + title: Cat + type: object + Dog: + properties: + bark_volume: + type: integer + pet_type: + default: dog + enum: + - dog + type: string + required: + - bark_volume + title: Dog + type: object + description: A response containing a pet. + properties: + pet: + oneOf: + - $ref: '#/$defs/Cat' + - $ref: '#/$defs/Dog' + required: + - pet + title: PetResponse + type: object + responseMimeType: application/json + responseModalities: + - TEXT + uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash:generateContent + response: + headers: + alt-svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + content-length: + - '594' + content-type: + - application/json; charset=UTF-8 + server-timing: + - gfet4t7; dur=1682 + transfer-encoding: + - chunked + vary: + - Origin + - X-Origin + - Referer + parsed_body: + candidates: + - content: + parts: + - text: '{"pet":{"pet_type":"cat","meow_volume":5}}' + role: model + finishReason: STOP + index: 0 + modelVersion: gemini-2.5-flash + responseId: B5MUaePHJLHd7M8PqfX5qAQ + usageMetadata: + candidatesTokenCount: 16 + promptTokenCount: 14 + promptTokensDetails: + - modality: TEXT + tokenCount: 14 + thoughtsTokenCount: 181 + totalTokenCount: 211 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_google/test_google_integer_enum_native_output.yaml b/tests/models/cassettes/test_google/test_google_integer_enum_native_output.yaml new file mode 100644 index 0000000000..3c16730afe --- /dev/null +++ b/tests/models/cassettes/test_google/test_google_integer_enum_native_output.yaml @@ -0,0 +1,84 @@ +interactions: +- request: + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '509' + content-type: + - application/json + host: + - generativelanguage.googleapis.com + method: POST + parsed_body: + contents: + - parts: + - text: Create a task named "Fix bug" with a priority + role: user + generationConfig: + responseJsonSchema: + $defs: + Priority: + enum: + - 1 + - 2 + - 3 + title: Priority + type: integer + description: A task with a priority level. + properties: + name: + type: string + priority: + $ref: '#/$defs/Priority' + required: + - name + - priority + title: Task + type: object + responseMimeType: application/json + responseModalities: + - TEXT + uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash:generateContent + response: + headers: + alt-svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + content-length: + - '584' + content-type: + - application/json; charset=UTF-8 + server-timing: + - gfet4t7; dur=2911 + transfer-encoding: + - chunked + vary: + - Origin + - X-Origin + - Referer + parsed_body: + candidates: + - content: + parts: + - text: '{"name": "Fix bug", "priority": 1}' + role: model + finishReason: STOP + index: 0 + modelVersion: gemini-2.5-flash + responseId: D5MUaYKeH9PjnsEPron42AQ + usageMetadata: + candidatesTokenCount: 13 + promptTokenCount: 12 + promptTokensDetails: + - modality: TEXT + tokenCount: 12 + thoughtsTokenCount: 448 + totalTokenCount: 473 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_google/test_google_optional_fields_native_output.yaml b/tests/models/cassettes/test_google/test_google_optional_fields_native_output.yaml new file mode 100644 index 0000000000..c0d360607a --- /dev/null +++ b/tests/models/cassettes/test_google/test_google_optional_fields_native_output.yaml @@ -0,0 +1,164 @@ +interactions: +- request: + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '538' + content-type: + - application/json + host: + - generativelanguage.googleapis.com + method: POST + parsed_body: + contents: + - parts: + - text: Tell me about London, UK with population 9 million + role: user + generationConfig: + responseJsonSchema: + description: A city and its country. + properties: + city: + type: string + country: + anyOf: + - type: string + - type: 'null' + default: null + population: + anyOf: + - type: integer + - type: 'null' + default: null + required: + - city + title: CityLocation + type: object + responseMimeType: application/json + responseModalities: + - TEXT + uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash:generateContent + response: + headers: + alt-svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + content-length: + - '612' + content-type: + - application/json; charset=UTF-8 + server-timing: + - gfet4t7; dur=1364 + transfer-encoding: + - chunked + vary: + - Origin + - X-Origin + - Referer + parsed_body: + candidates: + - content: + parts: + - text: '{"city": "London", "country": "UK", "population": 9000000}' + role: model + finishReason: STOP + index: 0 + modelVersion: gemini-2.5-flash + responseId: C5MUaeyaDuzBxN8P-KjW-AY + usageMetadata: + candidatesTokenCount: 24 + promptTokenCount: 12 + promptTokensDetails: + - modality: TEXT + tokenCount: 12 + thoughtsTokenCount: 130 + totalTokenCount: 166 + status: + code: 200 + message: OK +- request: + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '514' + content-type: + - application/json + host: + - generativelanguage.googleapis.com + method: POST + parsed_body: + contents: + - parts: + - text: 'Just tell me a city: Paris' + role: user + generationConfig: + responseJsonSchema: + description: A city and its country. + properties: + city: + type: string + country: + anyOf: + - type: string + - type: 'null' + default: null + population: + anyOf: + - type: integer + - type: 'null' + default: null + required: + - city + title: CityLocation + type: object + responseMimeType: application/json + responseModalities: + - TEXT + uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash:generateContent + response: + headers: + alt-svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + content-length: + - '561' + content-type: + - application/json; charset=UTF-8 + server-timing: + - gfet4t7; dur=1128 + transfer-encoding: + - chunked + vary: + - Origin + - X-Origin + - Referer + parsed_body: + candidates: + - content: + parts: + - text: '{"city": "Paris"}' + role: model + finishReason: STOP + index: 0 + modelVersion: gemini-2.5-flash + responseId: DJMUaf3RGp7SxN8PlIu6wQY + usageMetadata: + candidatesTokenCount: 6 + promptTokenCount: 8 + promptTokensDetails: + - modality: TEXT + tokenCount: 8 + thoughtsTokenCount: 99 + totalTokenCount: 113 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_google/test_google_prefix_items_native_output.yaml b/tests/models/cassettes/test_google/test_google_prefix_items_native_output.yaml new file mode 100644 index 0000000000..8a00918f25 --- /dev/null +++ b/tests/models/cassettes/test_google/test_google_prefix_items_native_output.yaml @@ -0,0 +1,78 @@ +interactions: +- request: + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '508' + content-type: + - application/json + host: + - generativelanguage.googleapis.com + method: POST + parsed_body: + contents: + - parts: + - text: 'Give me coordinates for New York City: latitude 40.7128, longitude -74.0060' + role: user + generationConfig: + responseJsonSchema: + description: A 2D coordinate with latitude and longitude. + properties: + point: + maxItems: 2 + minItems: 2 + prefixItems: + - type: number + - type: number + type: array + required: + - point + title: Coordinate + type: object + responseMimeType: application/json + responseModalities: + - TEXT + uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash:generateContent + response: + headers: + alt-svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + content-length: + - '573' + content-type: + - application/json; charset=UTF-8 + server-timing: + - gfet4t7; dur=1093 + transfer-encoding: + - chunked + vary: + - Origin + - X-Origin + - Referer + parsed_body: + candidates: + - content: + parts: + - text: '{"point":[40.7128,-74.006]}' + role: model + finishReason: STOP + index: 0 + modelVersion: gemini-2.5-flash + responseId: uAMWaebbNvegxN8P06_M2A4 + usageMetadata: + candidatesTokenCount: 18 + promptTokenCount: 28 + promptTokensDetails: + - modality: TEXT + tokenCount: 28 + thoughtsTokenCount: 108 + totalTokenCount: 154 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_google/test_google_recursive_schema_native_output.yaml b/tests/models/cassettes/test_google/test_google_recursive_schema_native_output.yaml new file mode 100644 index 0000000000..0e05d1e1cc --- /dev/null +++ b/tests/models/cassettes/test_google/test_google_recursive_schema_native_output.yaml @@ -0,0 +1,95 @@ +interactions: +- request: + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '556' + content-type: + - application/json + host: + - generativelanguage.googleapis.com + method: POST + parsed_body: + contents: + - parts: + - text: Create a simple tree with root "A" and two children "B" and "C" + role: user + generationConfig: + responseJsonSchema: + $defs: + TreeNode: + description: A node in a tree structure. + properties: + children: + default: [] + items: + $ref: '#/$defs/TreeNode' + type: array + value: + type: string + required: + - value + title: TreeNode + type: object + $ref: '#/$defs/TreeNode' + title: TreeNode + responseMimeType: application/json + responseModalities: + - TEXT + uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent + response: + headers: + alt-svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + content-length: + - '773' + content-type: + - application/json; charset=UTF-8 + server-timing: + - gfet4t7; dur=859 + transfer-encoding: + - chunked + vary: + - Origin + - X-Origin + - Referer + parsed_body: + candidates: + - avgLogprobs: -0.03688501318295797 + content: + parts: + - text: |- + { + "value": "A", + "children": [ + { + "value": "B" + }, + { + "value": "C" + } + ] + } + role: model + finishReason: STOP + modelVersion: gemini-2.0-flash + responseId: mpMUaYufEZ2qxN8Pr4qf6Qs + usageMetadata: + candidatesTokenCount: 48 + candidatesTokensDetails: + - modality: TEXT + tokenCount: 48 + promptTokenCount: 19 + promptTokensDetails: + - modality: TEXT + tokenCount: 19 + totalTokenCount: 67 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_google/test_google_recursive_schema_native_output_gemini_2_5.yaml b/tests/models/cassettes/test_google/test_google_recursive_schema_native_output_gemini_2_5.yaml new file mode 100644 index 0000000000..9aa2df2a5a --- /dev/null +++ b/tests/models/cassettes/test_google/test_google_recursive_schema_native_output_gemini_2_5.yaml @@ -0,0 +1,141 @@ +interactions: +- request: + body: grant_type=%5B%27refresh_token%27%5D&client_id=%5B%27scrubbed%27%5D&client_secret=%5B%27scrubbed%27%5D&refresh_token=%5B%27scrubbed%27%5D + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '137' + content-type: + - application/x-www-form-urlencoded + method: POST + uri: https://oauth2.googleapis.com/token + response: + headers: + alt-svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + cache-control: + - no-cache, no-store, max-age=0, must-revalidate + content-length: + - '1520' + content-type: + - application/json; charset=utf-8 + expires: + - Mon, 01 Jan 1990 00:00:00 GMT + pragma: + - no-cache + transfer-encoding: + - chunked + vary: + - Origin + - X-Origin + - Referer + parsed_body: + access_token: scrubbed + expires_in: 3599 + id_token: eyJhbGciOiJSUzI1NiIsImtpZCI6IjRmZWI0NGYwZjdhN2UyN2M3YzQwMzM3OWFmZjIwYWY1YzhjZjUyZGMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiIzMjU1NTk0MDU1OS5hcHBzLmdvb2dsZXVzZXJjb250ZW50LmNvbSIsImF1ZCI6IjMyNTU1OTQwNTU5LmFwcHMuZ29vZ2xldXNlcmNvbnRlbnQuY29tIiwic3ViIjoiMTA0MDMyODc1Njg3NDUwNzA3NzUwIiwiaGQiOiJjYXB0dXJlZGtub3dsZWRnZS5haSIsImVtYWlsIjoiY29ucmFkQGNhcHR1cmVka25vd2xlZGdlLmFpIiwiZW1haWxfdmVyaWZpZWQiOnRydWUsImF0X2hhc2giOiJ2emc3MEN0a1FhcnBJNVMzYWJZY1ZnIiwiaWF0IjoxNzYyOTU2NjUxLCJleHAiOjE3NjI5NjAyNTF9.P0kjqqgbGDIEfRkaCL76T1rRV1CC6ypQjWLlq8IWDgFhA6xMLOgcoN3eCU0yFg8lgoY_SI2C2oaQWMep9dNZbF4yil376ohzyuxkzyjjjfWmf-IuxDS9_s4IbIOut90XLM_R1SxWA-nc_nrki3OeYbvss0BWh28_BAvYLuMI4EVqW5QnlW1VmYj46kgn80YW9PEwSwei1h99ew9KLg7e9Fhb1LIXdU7zu1NkGjbvygirN3NKEZkry55w2U_h8ItPRes0MqJUFqpJzto92-GtpKhPjbIvmPJfmepxec9Tq-VU5IK24RqmYtNmzT5ZgyOXQtUni-9zhKjWsP8kIbGTEg + scope: openid https://www.googleapis.com/auth/accounts.reauth https://www.googleapis.com/auth/appengine.admin https://www.googleapis.com/auth/cloud-platform + https://www.googleapis.com/auth/userinfo.email https://www.googleapis.com/auth/compute https://www.googleapis.com/auth/sqlservice.login + token_type: Bearer + status: + code: 200 + message: OK +- request: + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '556' + content-type: + - application/json + host: + - aiplatform.googleapis.com + method: POST + parsed_body: + contents: + - parts: + - text: Create a simple tree with root "A" and two children "B" and "C" + role: user + generationConfig: + responseJsonSchema: + $defs: + TreeNode: + description: A node in a tree structure. + properties: + children: + default: [] + items: + $ref: '#/$defs/TreeNode' + type: array + value: + type: string + required: + - value + title: TreeNode + type: object + $ref: '#/$defs/TreeNode' + title: TreeNode + responseMimeType: application/json + responseModalities: + - TEXT + uri: https://aiplatform.googleapis.com/v1beta1/projects/pydantic-ai/locations/global/publishers/google/models/gemini-2.5-flash:generateContent + response: + headers: + alt-svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + content-length: + - '882' + content-type: + - application/json; charset=UTF-8 + transfer-encoding: + - chunked + vary: + - Origin + - X-Origin + - Referer + parsed_body: + candidates: + - avgLogprobs: -1.1582396030426025 + content: + parts: + - text: |- + { + "value": "A", + "children": [ + { + "value": "B" + }, + { + "value": "C" + } + ] + } + role: model + finishReason: STOP + createTime: '2025-11-12T14:10:52.206764Z' + modelVersion: gemini-2.5-flash + responseId: bJUUaazPDI-Kn9kPwNOc-AQ + usageMetadata: + candidatesTokenCount: 48 + candidatesTokensDetails: + - modality: TEXT + tokenCount: 48 + promptTokenCount: 19 + promptTokensDetails: + - modality: TEXT + tokenCount: 19 + thoughtsTokenCount: 153 + totalTokenCount: 220 + trafficType: ON_DEMAND + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py index d07e23452e..e1195e55aa 100644 --- a/tests/models/test_gemini.py +++ b/tests/models/test_gemini.py @@ -45,7 +45,6 @@ _gemini_streamed_response_ta, _GeminiCandidates, _GeminiContent, - _GeminiFunction, _GeminiFunctionCall, _GeminiFunctionCallingConfig, _GeminiFunctionCallPart, @@ -55,7 +54,6 @@ _GeminiTextPart, _GeminiThoughtPart, _GeminiToolConfig, - _GeminiTools, _GeminiUsageMetaData, _metadata_as_usage, ) @@ -135,32 +133,39 @@ async def test_model_tools(allow_model_requests: None): tools = m._get_tools(mrp) tool_config = m._get_tool_config(mrp, tools) assert tools == snapshot( - _GeminiTools( - function_declarations=[ - _GeminiFunction( - name='foo', - description='This is foo', - parameters={'type': 'object', 'properties': {'bar': {'type': 'number'}}}, - ), - _GeminiFunction( - name='apple', - description='This is apple', - parameters={ + { + 'function_declarations': [ + { + 'name': 'foo', + 'description': 'This is foo', + 'parameters': { 'type': 'object', - 'properties': {'banana': {'type': 'array', 'items': {'type': 'number'}}}, + 'title': 'Foo', + 'properties': {'bar': {'type': 'number', 'title': 'Bar'}}, }, - ), - _GeminiFunction( - name='result', - description='This is the tool for the final Result', - parameters={ + }, + { + 'name': 'apple', + 'description': 'This is apple', + 'parameters': { + 'type': 'object', + 'properties': { + 'banana': {'type': 'array', 'title': 'Banana', 'items': {'type': 'number', 'title': 'Bar'}} + }, + }, + }, + { + 'name': 'result', + 'description': 'This is the tool for the final Result', + 'parameters': { 'type': 'object', + 'title': 'Result', 'properties': {'spam': {'type': 'number'}}, 'required': ['spam'], }, - ), + }, ] - ) + } ) assert tool_config is None @@ -183,18 +188,15 @@ async def test_require_response_tool(allow_model_requests: None): tools = m._get_tools(mrp) tool_config = m._get_tool_config(mrp, tools) assert tools == snapshot( - _GeminiTools( - function_declarations=[ - _GeminiFunction( - name='result', - description='This is the tool for the final Result', - parameters={ - 'type': 'object', - 'properties': {'spam': {'type': 'number'}}, - }, - ), + { + 'function_declarations': [ + { + 'name': 'result', + 'description': 'This is the tool for the final Result', + 'parameters': {'type': 'object', 'title': 'Result', 'properties': {'spam': {'type': 'number'}}}, + } ] - ) + } ) assert tool_config == snapshot( _GeminiToolConfig( @@ -282,45 +284,44 @@ class Locations(BaseModel): 'parameters': { 'properties': { 'locations': { - 'items': { - 'properties': { - 'lat': {'type': 'number'}, - 'lng': {'default': 1.1, 'type': 'number'}, - 'chart': { - 'properties': { - 'x_axis': { - 'properties': { - 'label': { - 'default': '', - 'description': 'The label of the axis', - 'type': 'string', - } - }, - 'type': 'object', - }, - 'y_axis': { - 'properties': { - 'label': { - 'default': '', - 'description': 'The label of the axis', - 'type': 'string', - } - }, - 'type': 'object', - }, - }, - 'required': ['x_axis', 'y_axis'], - 'type': 'object', - }, - }, - 'required': ['lat', 'chart'], - 'type': 'object', - }, + 'items': {'$ref': '#/$defs/Location'}, + 'title': 'Locations', 'type': 'array', } }, 'required': ['locations'], + 'title': 'Locations', 'type': 'object', + '$defs': { + 'Axis': { + 'properties': { + 'label': { + 'default': '', + 'description': 'The label of the axis', + 'title': 'Label', + 'type': 'string', + } + }, + 'title': 'Axis', + 'type': 'object', + }, + 'Chart': { + 'properties': {'x_axis': {'$ref': '#/$defs/Axis'}, 'y_axis': {'$ref': '#/$defs/Axis'}}, + 'required': ['x_axis', 'y_axis'], + 'title': 'Chart', + 'type': 'object', + }, + 'Location': { + 'properties': { + 'lat': {'title': 'Lat', 'type': 'number'}, + 'lng': {'default': 1.1, 'title': 'Lng', 'type': 'number'}, + 'chart': {'$ref': '#/$defs/Chart'}, + }, + 'required': ['lat', 'chart'], + 'title': 'Location', + 'type': 'object', + }, + }, }, } ] @@ -379,13 +380,19 @@ class QueryDetails(BaseModel): 'parameters': { 'properties': { 'progress': { - 'items': {'enum': ['100', '80', '60', '40', '20'], 'type': 'string'}, - 'type': 'array', - 'nullable': True, 'default': None, + 'title': 'Progress', + 'anyOf': [ + {'items': {'$ref': '#/$defs/ProgressEnum'}, 'type': 'array'}, + {'type': 'null'}, + ], } }, + 'title': 'QueryDetails', 'type': 'object', + '$defs': { + 'ProgressEnum': {'enum': [100, 80, 60, 40, 20], 'title': 'ProgressEnum', 'type': 'integer'} + }, }, } ] @@ -425,18 +432,21 @@ class Locations(BaseModel): 'description': 'This is the tool for the final Result', 'parameters': { 'properties': { - 'op_location': { + 'op_location': {'default': None, 'anyOf': [{'$ref': '#/$defs/Location'}, {'type': 'null'}]} + }, + 'title': 'Locations', + 'type': 'object', + '$defs': { + 'Location': { 'properties': { - 'lat': {'type': 'number'}, - 'lng': {'type': 'number'}, + 'lat': {'title': 'Lat', 'type': 'number'}, + 'lng': {'title': 'Lng', 'type': 'number'}, }, 'required': ['lat', 'lng'], - 'nullable': True, + 'title': 'Location', 'type': 'object', - 'default': None, } }, - 'type': 'object', }, } ] @@ -444,52 +454,6 @@ class Locations(BaseModel): ) -async def test_json_def_recursive(allow_model_requests: None): - class Location(BaseModel): - lat: float - lng: float - nested_locations: list[Location] - - json_schema = Location.model_json_schema() - assert json_schema == snapshot( - { - '$defs': { - 'Location': { - 'properties': { - 'lat': {'title': 'Lat', 'type': 'number'}, - 'lng': {'title': 'Lng', 'type': 'number'}, - 'nested_locations': { - 'items': {'$ref': '#/$defs/Location'}, - 'title': 'Nested Locations', - 'type': 'array', - }, - }, - 'required': ['lat', 'lng', 'nested_locations'], - 'title': 'Location', - 'type': 'object', - } - }, - '$ref': '#/$defs/Location', - } - ) - - m = GeminiModel('gemini-1.5-flash', provider=GoogleGLAProvider(api_key='via-arg')) - output_tool = ToolDefinition( - name='result', - description='This is the tool for the final Result', - parameters_json_schema=json_schema, - ) - with pytest.raises(UserError, match=r'Recursive `\$ref`s in JSON Schema are not supported by Gemini'): - mrp = ModelRequestParameters( - function_tools=[], - allow_text_output=True, - output_tools=[output_tool], - output_mode='text', - output_object=None, - ) - mrp = m.customize_request_parameters(mrp) - - async def test_json_def_date(allow_model_requests: None): class FormattedStringFields(BaseModel): d: datetime.date @@ -527,24 +491,25 @@ class FormattedStringFields(BaseModel): ) mrp = m.customize_request_parameters(mrp) assert m._get_tools(mrp) == snapshot( - _GeminiTools( - function_declarations=[ - _GeminiFunction( - description='This is the tool for the final Result', - name='result', - parameters={ + { + 'function_declarations': [ + { + 'name': 'result', + 'description': 'This is the tool for the final Result', + 'parameters': { 'properties': { - 'd': {'description': 'Format: date', 'type': 'string'}, - 'dt': {'description': 'Format: date-time', 'type': 'string'}, - 't': {'description': 'Format: time', 'type': 'string'}, - 'td': {'description': 'my timedelta (format: duration)', 'type': 'string'}, + 'd': {'title': 'D', 'type': 'string', 'description': 'Format: date'}, + 'dt': {'title': 'Dt', 'type': 'string', 'description': 'Format: date-time'}, + 't': {'description': 'Format: time', 'title': 'T', 'type': 'string'}, + 'td': {'description': 'my timedelta (format: duration)', 'title': 'Td', 'type': 'string'}, }, 'required': ['d', 'dt', 't', 'td'], + 'title': 'FormattedStringFields', 'type': 'object', }, - ) + } ] - ) + } ) @@ -1449,19 +1414,18 @@ async def get_temperature(location: CurrentLocation) -> float: # pragma: no cov @pytest.mark.vcr() async def test_gemini_additional_properties_is_true(allow_model_requests: None, gemini_api_key: str): + """Test that additionalProperties with schemas now work natively (no warning since Nov 2025 announcement).""" m = GeminiModel('gemini-1.5-flash', provider=GoogleGLAProvider(api_key=gemini_api_key)) agent = Agent(m) - with pytest.warns(UserWarning, match='.*additionalProperties.*'): - - @agent.tool_plain - async def get_temperature(location: dict[str, CurrentLocation]) -> float: # pragma: no cover - return 20.0 + @agent.tool_plain + async def get_temperature(location: dict[str, CurrentLocation]) -> float: # pragma: no cover + return 20.0 - result = await agent.run('What is the temperature in Tokyo?') - assert result.output == snapshot( - 'I need a location dictionary to use the `get_temperature` function. I cannot provide the temperature in Tokyo without more information.\n' - ) + result = await agent.run('What is the temperature in Tokyo?') + assert result.output == snapshot( + 'I need a location dictionary to use the `get_temperature` function. I cannot provide the temperature in Tokyo without more information.\n' + ) @pytest.mark.vcr() diff --git a/tests/models/test_google.py b/tests/models/test_google.py index 82332f38ef..c59638f46c 100644 --- a/tests/models/test_google.py +++ b/tests/models/test_google.py @@ -3154,6 +3154,170 @@ async def test_google_httpx_client_is_not_closed(allow_model_requests: None, gem assert result.output == snapshot('The capital of Mexico is **Mexico City**.') +async def test_google_discriminated_union_native_output(allow_model_requests: None, google_provider: GoogleProvider): + """Test discriminated unions with oneOf and discriminator field.""" + from typing import Literal + + from pydantic import Field + + m = GoogleModel('gemini-2.5-flash', provider=google_provider) + + class Cat(BaseModel): + pet_type: Literal['cat'] = 'cat' + meow_volume: int + + class Dog(BaseModel): + pet_type: Literal['dog'] = 'dog' + bark_volume: int + + class PetResponse(BaseModel): + """A response containing a pet.""" + + pet: Cat | Dog = Field(discriminator='pet_type') + + agent = Agent(m, output_type=NativeOutput(PetResponse)) + + result = await agent.run('Tell me about a cat with a meow volume of 5') + assert result.output.pet.pet_type == 'cat' + assert isinstance(result.output.pet, Cat) + assert result.output.pet.meow_volume == snapshot(5) + + +async def test_google_recursive_schema_native_output(allow_model_requests: None, google_provider: GoogleProvider): + """Test recursive schemas with $ref and $defs.""" + m = GoogleModel('gemini-2.0-flash', provider=google_provider) + + class TreeNode(BaseModel): + """A node in a tree structure.""" + + value: str + children: list[TreeNode] = [] + + agent = Agent(m, output_type=NativeOutput(TreeNode)) + + result = await agent.run('Create a simple tree with root "A" and two children "B" and "C"') + assert result.output.value == snapshot('A') + assert len(result.output.children) == snapshot(2) + assert {child.value for child in result.output.children} == snapshot({'B', 'C'}) + + +async def test_google_recursive_schema_native_output_gemini_2_5( + allow_model_requests: None, vertex_provider: GoogleProvider +): # pragma: lax no cover + """Test recursive schemas with $ref and $defs using gemini-2.5-flash on Vertex AI. + + NOTE: Recursive schemas with gemini-2.5-flash FAIL on GLA (500 error) but PASS on Vertex AI. + This test uses vertex_provider to demonstrate the feature works on Vertex AI. + The GLA issue needs to be reported to Google. + """ + m = GoogleModel('gemini-2.5-flash', provider=vertex_provider) + + class TreeNode(BaseModel): + """A node in a tree structure.""" + + value: str + children: list[TreeNode] = [] + + agent = Agent(m, output_type=NativeOutput(TreeNode)) + + result = await agent.run('Create a simple tree with root "A" and two children "B" and "C"') + assert result.output.value == 'A' + assert len(result.output.children) == 2 + assert {child.value for child in result.output.children} == {'B', 'C'} + + +async def test_google_dict_with_additional_properties_native_output( + allow_model_requests: None, google_provider: GoogleProvider +): + """Test dicts with additionalProperties.""" + m = GoogleModel('gemini-2.5-flash', provider=google_provider) + + class ConfigResponse(BaseModel): + """A response with configuration metadata.""" + + name: str + metadata: dict[str, str] + + agent = Agent(m, output_type=NativeOutput(ConfigResponse)) + + result = await agent.run('Create a config named "api-config" with metadata author="Alice" and version="1.0"') + assert result.output.name == snapshot('api-config') + assert result.output.metadata == snapshot({'author': 'Alice', 'version': '1.0'}) + + +async def test_google_optional_fields_native_output(allow_model_requests: None, google_provider: GoogleProvider): + """Test optional/nullable fields with type: 'null'.""" + m = GoogleModel('gemini-2.5-flash', provider=google_provider) + + class CityLocation(BaseModel): + """A city and its country.""" + + city: str + country: str | None = None + population: int | None = None + + agent = Agent(m, output_type=NativeOutput(CityLocation)) + + # Test with all fields provided + result = await agent.run('Tell me about London, UK with population 9 million') + assert result.output.city == snapshot('London') + assert result.output.country == snapshot('UK') + assert result.output.population is not None + + # Test with optional fields as None + result2 = await agent.run('Just tell me a city: Paris') + assert result2.output.city == snapshot('Paris') + + +async def test_google_integer_enum_native_output(allow_model_requests: None, google_provider: GoogleProvider): + """Test integer enums work natively without string conversion.""" + from enum import IntEnum + + m = GoogleModel('gemini-2.5-flash', provider=google_provider) + + class Priority(IntEnum): + LOW = 1 + MEDIUM = 2 + HIGH = 3 + + class Task(BaseModel): + """A task with a priority level.""" + + name: str + priority: Priority + + agent = Agent(m, output_type=NativeOutput(Task)) + + result = await agent.run('Create a task named "Fix bug" with a priority') + assert result.output.name == snapshot('Fix bug') + # Verify it returns a valid Priority enum (any value is fine, we're testing schema support) + assert isinstance(result.output.priority, Priority) + assert result.output.priority in {Priority.LOW, Priority.MEDIUM, Priority.HIGH} + # Verify it's an actual integer value + assert isinstance(result.output.priority.value, int) + + +async def test_google_prefix_items_native_output(allow_model_requests: None, google_provider: GoogleProvider): + """Test prefixItems (tuple types) work natively without conversion to items.""" + m = GoogleModel('gemini-2.5-flash', provider=google_provider) + + class Coordinate(BaseModel): + """A 2D coordinate with latitude and longitude.""" + + point: tuple[float, float] # This generates prefixItems in JSON schema + + agent = Agent(m, output_type=NativeOutput(Coordinate)) + + result = await agent.run('Give me coordinates for New York City: latitude 40.7128, longitude -74.0060') + assert len(result.output.point) == snapshot(2) + # Verify both values are floats + assert isinstance(result.output.point[0], float) + assert isinstance(result.output.point[1], float) + # Rough check for NYC coordinates (latitude ~40, longitude ~-74) + assert 40 <= result.output.point[0] <= 41 + assert -75 <= result.output.point[1] <= -73 + + def test_google_process_response_filters_empty_text_parts(google_provider: GoogleProvider): model = GoogleModel('gemini-2.5-pro', provider=google_provider) response = _generate_response_with_texts(response_id='resp-123', texts=['', 'first', '', 'second'])