285 changes: 283 additions & 2 deletions pydantic_ai_slim/pydantic_ai/_json_schema.py
@@ -4,12 +4,14 @@
from abc import ABC, abstractmethod
from copy import deepcopy
from dataclasses import dataclass
-from typing import Any, Literal
+from typing import Any, Literal, cast

from .exceptions import UserError

JsonSchema = dict[str, Any]

__all__ = ['JsonSchemaTransformer', 'InlineDefsJsonSchemaTransformer']


@dataclass(init=False)
class JsonSchemaTransformer(ABC):
@@ -26,14 +28,18 @@ def __init__(
strict: bool | None = None,
prefer_inlined_defs: bool = False,
simplify_nullable_unions: bool = False, # TODO (v2): Remove this, no longer used
flatten_allof: bool = False,
):
self.schema = schema

self.strict = strict
-self.is_strict_compatible = True  # Can be set to False by subclasses to set `strict` on `ToolDefinition` when set not set by user explicitly
+# Can be set to False by subclasses to set `strict` on `ToolDefinition`
+# when not set explicitly by the user.
+self.is_strict_compatible = True
[Review comment, Collaborator]: With this change, it's now ambiguous which line the comment belongs to :) Can you move it back please, even though the line is uncomfortably long? Or add a blank line between the two vars?

self.prefer_inlined_defs = prefer_inlined_defs
self.simplify_nullable_unions = simplify_nullable_unions
self.flatten_allof = flatten_allof

self.defs: dict[str, JsonSchema] = self.schema.get('$defs', {})
self.refs_stack: list[str] = []
Expand Down Expand Up @@ -73,6 +79,10 @@ def walk(self) -> JsonSchema:
return handled

def _handle(self, schema: JsonSchema) -> JsonSchema:
# Flatten allOf if requested, before processing the schema
if self.flatten_allof:
schema = _recurse_flatten_allof(schema)

nested_refs = 0
if self.prefer_inlined_defs:
while ref := schema.get('$ref'):
@@ -187,3 +197,274 @@ def __init__(self, schema: JsonSchema, *, strict: bool | None = None):

def transform(self, schema: JsonSchema) -> JsonSchema:
return schema


def _get_type_set(schema: JsonSchema) -> set[str] | None:
"""Extract type(s) from a schema as a set of strings."""
schema_type = schema.get('type')
if isinstance(schema_type, list):
return {str(t) for t in cast(list[Any], schema_type)}
if isinstance(schema_type, str):
return {schema_type}
return None
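
# Illustrative usage (not part of the diff): _get_type_set normalizes the JSON
# Schema `type` keyword to a set of type names, e.g.
#   _get_type_set({'type': 'string'})           -> {'string'}
#   _get_type_set({'type': ['string', 'null']}) -> {'string', 'null'}
#   _get_type_set({'anyOf': []})                -> None  (no `type` keyword)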


def _process_nested_schemas_without_allof(s: JsonSchema) -> JsonSchema:
"""Process nested schemas recursively when there is no allOf at the current level."""
schema_type = s.get('type')
if schema_type == 'object':
[Review comment, Collaborator]: We already have pretty similar code here:

    type_ = schema.get('type')
    if type_ == 'object':
        schema = self._handle_object(schema)
    elif type_ == 'array':
        schema = self._handle_array(schema)
    elif type_ is None:
        schema = self._handle_union(schema, 'anyOf')
        schema = self._handle_union(schema, 'oneOf')

Would it be an option to merge this new processing logic into that one, somehow? Or at least make them less seemingly duplicative? I just don't love the idea of recursively walking the JSON schema twice with different implementations.

if isinstance(s.get('properties'), dict):
s['properties'] = {
k: _recurse_flatten_allof(cast(JsonSchema, v))
for k, v in s['properties'].items()
if isinstance(v, dict)
}
if isinstance(s.get('additionalProperties'), dict):
s['additionalProperties'] = _recurse_flatten_allof(cast(JsonSchema, s['additionalProperties']))
if isinstance(s.get('patternProperties'), dict):
s['patternProperties'] = {
k: _recurse_flatten_allof(cast(JsonSchema, v))
for k, v in s['patternProperties'].items()
if isinstance(v, dict)
}
elif schema_type == 'array':
if isinstance(s.get('items'), dict):
s['items'] = _recurse_flatten_allof(cast(JsonSchema, s['items']))
return s
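
# Illustrative (not part of the diff): this recursion means allOf groups nested
# under object properties or array items are flattened too, e.g. in
#   {'type': 'array', 'items': {'allOf': [...]}}
# the `items` schema is passed back through _recurse_flatten_allof.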


def _collect_base_schema_data(
result: JsonSchema,
) -> tuple[dict[str, JsonSchema], set[str], dict[str, JsonSchema], list[Any], list[set[str]]]:
"""Collect data from base schema: properties, required, patternProperties, additionalProperties."""
properties: dict[str, JsonSchema] = {}
required: set[str] = set()
pattern_properties: dict[str, JsonSchema] = {}
additional_values: list[Any] = []
restricted_property_sets: list[set[str]] = []

base_properties = (
cast(dict[str, JsonSchema], result.get('properties', {})) if isinstance(result.get('properties'), dict) else {}
)
base_additional = result.get('additionalProperties')

if base_properties:
properties.update(base_properties)
if isinstance(result.get('required'), list):
required.update(result['required'])
if isinstance(result.get('patternProperties'), dict):
pattern_properties.update(result['patternProperties'])
if base_additional is False:
additional_values.append(False)
# Only restrict if base schema has properties; if base has no properties but additionalProperties: False,
# it means no additional properties are allowed, but properties from allOf members are still valid
if base_properties:
restricted_property_sets.append(set(base_properties.keys()))

return properties, required, pattern_properties, additional_values, restricted_property_sets
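
# Worked example (illustrative, not from the PR's tests): for a base schema
#   {'type': 'object', 'properties': {'a': {'type': 'string'}}, 'additionalProperties': False}
# this returns ({'a': {'type': 'string'}}, set(), {}, [False], [{'a'}]).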


def _collect_member_data(
processed_members: list[JsonSchema],
properties: dict[str, JsonSchema],
required: set[str],
pattern_properties: dict[str, JsonSchema],
additional_values: list[Any],
restricted_property_sets: list[set[str]],
members_properties: list[dict[str, JsonSchema]],
members_additional_props: list[Any],
) -> None:
"""Collect data from allOf members and update the collections."""
for m in processed_members:
member_props = (
cast(dict[str, JsonSchema], m.get('properties', {})) if isinstance(m.get('properties'), dict) else {}
)
members_properties.append(member_props)
members_additional_props.append(m.get('additionalProperties'))

if member_props:
properties.update(member_props)
if isinstance(m.get('required'), list):
required.update(m['required'])
if isinstance(m.get('patternProperties'), dict):
pattern_properties.update(m['patternProperties'])
if 'additionalProperties' in m:
additional_values.append(m['additionalProperties'])
if m['additionalProperties'] is False:
restricted_property_sets.append(set(member_props.keys()))


def _filter_by_restricted_property_sets(
properties: dict[str, JsonSchema], required: set[str], restricted_property_sets: list[set[str]]
) -> tuple[dict[str, JsonSchema], set[str]]:
"""Filter properties and required by restricted property sets (intersection when some/all have additionalProperties: False)."""
if not restricted_property_sets:
return properties, required

# Intersection of allowed properties from all members with additionalProperties: False
allowed_names = restricted_property_sets[0].copy()
for prop_set in restricted_property_sets[1:]:
allowed_names &= prop_set
# Filter properties to only include allowed names
if allowed_names:
properties = {k: v for k, v in properties.items() if k in allowed_names}
required = {r for r in required if r in allowed_names}
else:
# Empty intersection - remove all properties
properties = {}
required = set()

return properties, required
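
# Worked example (illustrative): when several members set additionalProperties: False,
# only property names they all allow survive the intersection:
#   restricted_property_sets = [{'a', 'b'}, {'b', 'c'}]  ->  allowed_names = {'b'}
# so properties 'a', 'b', 'c' are filtered down to just 'b', and `required`
# is filtered the same way.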


def _filter_incompatible_properties(
properties: dict[str, JsonSchema],
required: set[str],
members_properties: list[dict[str, JsonSchema]],
members_additional_props: list[Any],
) -> tuple[dict[str, JsonSchema], set[str]]:
"""Filter incompatible properties based on additionalProperties constraints."""
if not properties:
return properties, required

incompatible_props: set[str] = set()

for prop_name, prop_schema in properties.items():
prop_types = _get_type_set(prop_schema)

# Check compatibility with each member (including base)
for member_props, member_additional in zip(members_properties, members_additional_props):
if prop_name in member_props:
# Property explicitly defined - check type compatibility
member_prop_types = _get_type_set(member_props[prop_name])
if prop_types and member_prop_types and not prop_types & member_prop_types:
incompatible_props.add(prop_name)
break
continue # Compatible, check next member
if isinstance(member_additional, dict):
allowed_types = _get_type_set(cast(JsonSchema, member_additional))
# Property type must be a subset of allowed types
if prop_types and allowed_types and not (prop_types <= allowed_types):
incompatible_props.add(prop_name)
break

if incompatible_props:
allowed_names = {k for k in properties.keys() if k not in incompatible_props}
properties = {k: v for k, v in properties.items() if k in allowed_names}
required = {r for r in required if r in allowed_names}

return properties, required
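
# Worked example (illustrative): a merged property is dropped when a member that
# does not declare it only allows additional properties of another type:
#   properties = {'x': {'type': 'string'}}
#   one member has additionalProperties = {'type': 'integer'}
# {'string'} is not a subset of {'integer'}, so 'x' is removed from both
# `properties` and `required`.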


def _process_result_nested_schemas(result: JsonSchema) -> None:
"""Recursively process nested schemas in the result (additionalProperties, patternProperties, items)."""
if isinstance(result.get('additionalProperties'), dict):
result['additionalProperties'] = _recurse_flatten_allof(cast(JsonSchema, result['additionalProperties']))
if isinstance(result.get('patternProperties'), dict):
result['patternProperties'] = {
k: _recurse_flatten_allof(cast(JsonSchema, v))
for k, v in result['patternProperties'].items()
if isinstance(v, dict)
}
if isinstance(result.get('items'), dict):
result['items'] = _recurse_flatten_allof(cast(JsonSchema, result['items']))
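
# Note (illustrative): this helper mutates `result` in place and returns None;
# it is the final step of _recurse_flatten_allof below.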


def _recurse_flatten_allof(schema: JsonSchema) -> JsonSchema:
"""Recursively flatten allOf in a JSON schema.

This function:
1. Makes a deep copy of the schema
2. Flattens allOf at the current level
3. Recursively processes nested schemas (properties, items, etc.)
"""
s = deepcopy(schema)

# Case 1: No allOf - process nested schemas recursively and return
allof = s.get('allOf')
if not isinstance(allof, list) or not allof:
return _process_nested_schemas_without_allof(s)

# Check all members are dicts
members = cast(list[JsonSchema], allof)
if not all(isinstance(m, dict) for m in members):
return s

# Check all members are object-like (can be merged)
def _is_object_like(member: JsonSchema) -> bool:
member_type = member.get('type')
if member_type is None:
# No type but has object-like keys
keys = ('properties', 'additionalProperties', 'patternProperties')
return bool(any(k in member for k in keys))
return isinstance(member_type, str) and member_type == 'object'

if not all(_is_object_like(m) for m in members):
return s

# Recursively flatten each member first
processed_members = [_recurse_flatten_allof(m) for m in members]
result: JsonSchema = {k: v for k, v in s.items() if k != 'allOf'}
result['type'] = 'object'

# Collect data from base schema and members
base_properties = (
cast(dict[str, JsonSchema], result.get('properties', {})) if isinstance(result.get('properties'), dict) else {}
)
base_additional = result.get('additionalProperties')

properties, required, pattern_properties, additional_values, restricted_property_sets = _collect_base_schema_data(
result
)

# Then merge properties from all members
members_properties: list[dict[str, JsonSchema]] = [base_properties]
members_additional_props: list[Any] = [base_additional]

_collect_member_data(
processed_members,
properties,
required,
pattern_properties,
additional_values,
restricted_property_sets,
members_properties,
members_additional_props,
)

# Filter by restricted property sets and incompatible properties
properties, required = _filter_by_restricted_property_sets(properties, required, restricted_property_sets)
properties, required = _filter_incompatible_properties(
properties, required, members_properties, members_additional_props
)

# Apply filtered properties
if properties:
# Recursively flatten nested properties
result['properties'] = {k: _recurse_flatten_allof(v) for k, v in properties.items()}
if required:
result['required'] = sorted(required)
if pattern_properties:
result['patternProperties'] = {k: _recurse_flatten_allof(v) for k, v in pattern_properties.items()}

# Merge additionalProperties
if additional_values:
# If any is False, result is False (most restrictive)
if any(v is False for v in additional_values):
result['additionalProperties'] = False
# If there's exactly one dict schema, preserve it
elif len(additional_values) == 1 and isinstance(additional_values[0], dict):
result['additionalProperties'] = additional_values[0]
# If any is a dict schema (multiple), result is True (can't merge multiple schemas)
elif any(isinstance(v, dict) for v in additional_values):
result['additionalProperties'] = True
# Otherwise, default to True
else:
result['additionalProperties'] = True

# Recursively process nested schemas (additionalProperties, patternProperties)
# Note: items is only valid for array types, not object types, so result.get('items') should never
# be present when result['type'] == 'object'. However, we keep this check for robustness.
_process_result_nested_schemas(result)

return result
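
As a reviewer aid, here is a minimal before/after sketch of what this flattening does. The output shape is my reading of the merging rules above, not taken from the PR's test suite:

    schema = {
        'allOf': [
            {'type': 'object', 'properties': {'a': {'type': 'string'}}, 'required': ['a']},
            {'type': 'object', 'properties': {'b': {'type': 'integer'}}, 'required': ['b']},
        ]
    }
    _recurse_flatten_allof(schema)
    # -> {
    #     'type': 'object',
    #     'properties': {'a': {'type': 'string'}, 'b': {'type': 'integer'}},
    #     'required': ['a', 'b'],
    # }

Per the rules above, if any member had instead set `additionalProperties: False`, only the intersection of the restricted members' property names would survive, and the merged schema would itself carry `additionalProperties: False`.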
9 changes: 7 additions & 2 deletions pydantic_ai_slim/pydantic_ai/profiles/openai.py
@@ -143,7 +143,7 @@ class OpenAIJsonSchemaTransformer(JsonSchemaTransformer):
"""

def __init__(self, schema: JsonSchema, *, strict: bool | None = None):
-super().__init__(schema, strict=strict)
+super().__init__(schema, strict=strict, flatten_allof=True)
self.root_ref = schema.get('$ref')

def walk(self) -> JsonSchema:
@@ -157,7 +157,12 @@ def walk(self) -> JsonSchema:
if self.root_ref is not None:
result.pop('$ref', None) # We replace references to the self.root_ref with just '#' in the transform method
root_key = re.sub(r'^#/\$defs/', '', self.root_ref)
-result.update(self.defs.get(root_key) or {})
+# Use the transformed schema from $defs, not the original self.defs
+if '$defs' in result and root_key in result['$defs']:
+    result.update(result['$defs'][root_key])
+else:
+    # Fallback to original if transformed version not available (shouldn't happen in normal flow)
+    result.update(self.defs.get(root_key) or {})
[Review comment, Collaborator]: Please check the comment I had on this before; I don't think it's been addressed, or at least I don't understand why we still have these two paths instead of standardizing on only using self.defs.

return result
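
For context, a hedged sketch of how this transformer is typically driven (the constructor signature matches this diff; the surrounding pydantic-ai plumbing is elided):

    transformer = OpenAIJsonSchemaTransformer(schema, strict=True)
    result = transformer.walk()  # allOf groups are flattened, since flatten_allof=True is set above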
