Skip to content

Commit c3a7058

Browse files
ChiaXinLiangDouweM
andauthored
Fix StructuredDict with nested JSON schemas using $ref (#2570)
Co-authored-by: Douwe Maan <[email protected]>
1 parent 5b00781 commit c3a7058

20 files changed

+115
-24
lines changed

pydantic_ai_slim/pydantic_ai/profiles/_json_schema.py renamed to pydantic_ai_slim/pydantic_ai/_json_schema.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from dataclasses import dataclass
77
from typing import Any, Literal
88

9-
from pydantic_ai.exceptions import UserError
9+
from .exceptions import UserError
1010

1111
JsonSchema = dict[str, Any]
1212

pydantic_ai_slim/pydantic_ai/_utils.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import re
77
import time
88
import uuid
9-
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable, Iterator
9+
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable, Iterable, Iterator
1010
from contextlib import asynccontextmanager, suppress
1111
from dataclasses import dataclass, fields, is_dataclass
1212
from datetime import datetime, timezone
@@ -70,16 +70,33 @@ def check_object_json_schema(schema: JsonSchemaValue) -> ObjectJsonSchema:
7070

7171
if schema.get('type') == 'object':
7272
return schema
73-
elif schema.get('$ref') is not None:
74-
maybe_result = schema.get('$defs', {}).get(schema['$ref'][8:]) # This removes the initial "#/$defs/".
75-
76-
if "'$ref': '#/$defs/" in str(maybe_result):
77-
return schema # We can't remove the $defs because the schema contains other references
78-
return maybe_result
73+
elif ref := schema.get('$ref'):
74+
prefix = '#/$defs/'
75+
# Return the referenced schema unless it contains additional nested references.
76+
if (
77+
ref.startswith(prefix)
78+
and (resolved := schema.get('$defs', {}).get(ref[len(prefix) :]))
79+
and resolved.get('type') == 'object'
80+
and not _contains_ref(resolved)
81+
):
82+
return resolved
83+
return schema
7984
else:
8085
raise UserError('Schema must be an object')
8186

8287

88+
def _contains_ref(obj: JsonSchemaValue | list[JsonSchemaValue]) -> bool:
89+
"""Recursively check if an object contains any $ref keys."""
90+
items: Iterable[JsonSchemaValue]
91+
if isinstance(obj, dict):
92+
if '$ref' in obj:
93+
return True
94+
items = obj.values()
95+
else:
96+
items = obj
97+
return any(isinstance(item, dict | list) and _contains_ref(item) for item in items) # pyright: ignore[reportUnknownArgumentType]
98+
99+
83100
T = TypeVar('T')
84101

85102

pydantic_ai_slim/pydantic_ai/models/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from typing_extensions import TypeAliasType, TypedDict
2121

2222
from .. import _utils
23+
from .._json_schema import JsonSchemaTransformer
2324
from .._output import OutputObjectDefinition
2425
from .._parts_manager import ModelResponsePartsManager
2526
from .._run_context import RunContext
@@ -40,7 +41,6 @@
4041
)
4142
from ..output import OutputMode
4243
from ..profiles import DEFAULT_PROFILE, ModelProfile, ModelProfileSpec
43-
from ..profiles._json_schema import JsonSchemaTransformer
4444
from ..settings import ModelSettings
4545
from ..tools import ToolDefinition
4646
from ..usage import RequestUsage

pydantic_ai_slim/pydantic_ai/output.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from typing_extensions import TypeAliasType, TypeVar, deprecated
1111

1212
from . import _utils
13+
from ._json_schema import InlineDefsJsonSchemaTransformer
1314
from .messages import ToolCallPart
1415
from .tools import DeferredToolRequests, ObjectJsonSchema, RunContext, ToolDefinition
1516

@@ -311,6 +312,11 @@ def StructuredDict(
311312
"""
312313
json_schema = _utils.check_object_json_schema(json_schema)
313314

315+
# Pydantic `TypeAdapter` fails when `object.__get_pydantic_json_schema__` has `$defs`, so we inline them
316+
# See https://github.com/pydantic/pydantic/issues/12145
317+
if '$defs' in json_schema:
318+
json_schema = InlineDefsJsonSchemaTransformer(json_schema).walk()
319+
314320
if name:
315321
json_schema['title'] = name
316322

pydantic_ai_slim/pydantic_ai/profiles/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66

77
from typing_extensions import Self
88

9+
from .._json_schema import InlineDefsJsonSchemaTransformer, JsonSchemaTransformer
910
from ..output import StructuredOutputMode
10-
from ._json_schema import InlineDefsJsonSchemaTransformer, JsonSchemaTransformer
1111

1212
__all__ = [
1313
'ModelProfile',

pydantic_ai_slim/pydantic_ai/profiles/google.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44

55
from pydantic_ai.exceptions import UserError
66

7+
from .._json_schema import JsonSchema, JsonSchemaTransformer
78
from . import ModelProfile
8-
from ._json_schema import JsonSchema, JsonSchemaTransformer
99

1010

1111
def google_model_profile(model_name: str) -> ModelProfile | None:

pydantic_ai_slim/pydantic_ai/profiles/openai.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
from dataclasses import dataclass
77
from typing import Any, Literal
88

9+
from .._json_schema import JsonSchema, JsonSchemaTransformer
910
from . import ModelProfile
10-
from ._json_schema import JsonSchema, JsonSchemaTransformer
1111

1212
OpenAISystemPromptRole = Literal['system', 'developer', 'user']
1313

tests/models/test_openai.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,10 @@
3434
UserError,
3535
UserPromptPart,
3636
)
37+
from pydantic_ai._json_schema import InlineDefsJsonSchemaTransformer
3738
from pydantic_ai.builtin_tools import WebSearchTool
3839
from pydantic_ai.models import ModelRequestParameters
3940
from pydantic_ai.output import NativeOutput, PromptedOutput, TextOutput, ToolOutput
40-
from pydantic_ai.profiles._json_schema import InlineDefsJsonSchemaTransformer
4141
from pydantic_ai.profiles.openai import OpenAIModelProfile, openai_model_profile
4242
from pydantic_ai.result import RunUsage
4343
from pydantic_ai.settings import ModelSettings

tests/providers/test_azure.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
from inline_snapshot import snapshot
55
from pytest_mock import MockerFixture
66

7+
from pydantic_ai._json_schema import InlineDefsJsonSchemaTransformer
78
from pydantic_ai.agent import Agent
8-
from pydantic_ai.profiles._json_schema import InlineDefsJsonSchemaTransformer
99
from pydantic_ai.profiles.cohere import cohere_model_profile
1010
from pydantic_ai.profiles.deepseek import deepseek_model_profile
1111
from pydantic_ai.profiles.grok import grok_model_profile

tests/providers/test_bedrock.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import pytest
44
from pytest_mock import MockerFixture
55

6-
from pydantic_ai.profiles._json_schema import InlineDefsJsonSchemaTransformer
6+
from pydantic_ai._json_schema import InlineDefsJsonSchemaTransformer
77
from pydantic_ai.profiles.amazon import amazon_model_profile
88
from pydantic_ai.profiles.anthropic import anthropic_model_profile
99
from pydantic_ai.profiles.cohere import cohere_model_profile

0 commit comments

Comments
 (0)