-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Flatten allOf properties for OpenAI compatibility #3451
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 9 commits
afddade
8ff4db7
e4aebaf
e5815fd
b1058e9
1244bfd
f143029
eaa685e
cea2b10
27192da
756bdc4
6b47179
6114437
16e33e8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,12 +4,14 @@ | |
| from abc import ABC, abstractmethod | ||
| from copy import deepcopy | ||
| from dataclasses import dataclass | ||
| from typing import Any, Literal | ||
| from typing import Any, Literal, cast | ||
|
|
||
| from .exceptions import UserError | ||
|
|
||
| JsonSchema = dict[str, Any] | ||
|
|
||
| __all__ = ['JsonSchemaTransformer', 'InlineDefsJsonSchemaTransformer', 'flatten_allof'] | ||
|
|
||
|
|
||
| @dataclass(init=False) | ||
| class JsonSchemaTransformer(ABC): | ||
|
|
@@ -26,14 +28,18 @@ def __init__( | |
| strict: bool | None = None, | ||
| prefer_inlined_defs: bool = False, | ||
| simplify_nullable_unions: bool = False, | ||
| flatten_allof: bool = False, | ||
| ): | ||
| self.schema = schema | ||
|
|
||
| self.strict = strict | ||
| self.is_strict_compatible = True # Can be set to False by subclasses to set `strict` on `ToolDefinition` when set not set by user explicitly | ||
| # Can be set to False by subclasses to set `strict` on `ToolDefinition` | ||
| # when not set explicitly by the user. | ||
| self.is_strict_compatible = True | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. With this change, it's now ambiguous which line the comment belongs to :) Can you move it back please, even though the line is uncomfortably long? Or add a blank line between the 2 vars? |
||
|
|
||
| self.prefer_inlined_defs = prefer_inlined_defs | ||
| self.simplify_nullable_unions = simplify_nullable_unions | ||
| self.flatten_allof = flatten_allof | ||
|
|
||
| self.defs: dict[str, JsonSchema] = self.schema.get('$defs', {}) | ||
| self.refs_stack: list[str] = [] | ||
|
|
@@ -73,6 +79,10 @@ def walk(self) -> JsonSchema: | |
| return handled | ||
|
|
||
| def _handle(self, schema: JsonSchema) -> JsonSchema: | ||
| # Flatten allOf if requested, before processing the schema | ||
| if self.flatten_allof: | ||
| schema = flatten_allof(schema) | ||
|
|
||
| nested_refs = 0 | ||
| if self.prefer_inlined_defs: | ||
| while ref := schema.get('$ref'): | ||
|
|
@@ -188,3 +198,108 @@ def __init__(self, schema: JsonSchema, *, strict: bool | None = None): | |
|
|
||
| def transform(self, schema: JsonSchema) -> JsonSchema: | ||
| return schema | ||
|
|
||
|
|
||
def _allof_is_object_like(member: JsonSchema) -> bool:
    """Return whether an `allOf` member describes an object and can therefore be merged.

    A member with an explicit `type` qualifies only when that type is `'object'`.
    A member without `type` qualifies when it carries at least one object-only keyword
    (`properties`, `additionalProperties` or `patternProperties`).
    """
    member_type = member.get('type')
    if member_type is None:
        return any(key in member for key in ('properties', 'additionalProperties', 'patternProperties'))
    return member_type == 'object'


def _merge_additional_properties_values(values: list[Any]) -> bool:
    """Combine the `additionalProperties` values collected from `allOf` members.

    The result is `False` only when `values` is non-empty and every entry is literally `False`.
    Any `True` or sub-schema entry makes the merged value `True`; sub-schemas are never returned.
    """
    return not (bool(values) and all(value is False for value in values))


def _flatten_current_level(s: JsonSchema) -> JsonSchema:
    """Merge the `allOf` members of `s` into a single object schema, if they are all object-like.

    Returns `s` itself (unchanged) when `allOf` is missing, empty, not a list, or contains a member
    that is not a dict or not object-like. Otherwise returns a new dict that:

    - keeps every key of `s` except `allOf` and forces `type` to `'object'`;
    - unions `properties` and `patternProperties` (on a name clash the later member wins);
    - unions `required` and emits it sorted;
    - replaces `additionalProperties` with the merged member values when any member sets it.

    Only those four keywords are read from members; any other member keyword is not carried over.
    Child schemas are not descended into here -- that is the job of `_recurse_children`.
    """
    raw_members = s.get('allOf')
    if not isinstance(raw_members, list) or not raw_members:
        return s

    candidates = cast('list[Any]', raw_members)
    if not all(isinstance(m, dict) and _allof_is_object_like(cast('JsonSchema', m)) for m in candidates):
        return s

    # A member may itself carry a mergeable `allOf`; collapse it first so its properties are visible below.
    members = [_flatten_current_level(cast('JsonSchema', m)) for m in candidates]

    merged: JsonSchema = {k: v for k, v in s.items() if k != 'allOf'}
    merged['type'] = 'object'

    # Seed the accumulators with the parent's own keywords, guarding their types the same way as for members.
    properties: dict[str, Any] = {}
    own_properties = merged.get('properties')
    if isinstance(own_properties, dict):
        properties.update(cast('dict[str, Any]', own_properties))

    required: set[str] = set()
    own_required = merged.get('required')
    if isinstance(own_required, list):
        required.update(cast('list[str]', own_required))

    pattern_properties: dict[str, Any] = {}
    own_pattern_properties = merged.get('patternProperties')
    if isinstance(own_pattern_properties, dict):
        pattern_properties.update(cast('dict[str, Any]', own_pattern_properties))

    additional_values: list[Any] = []

    for member in members:
        member_properties = member.get('properties')
        if isinstance(member_properties, dict):
            properties.update(cast('dict[str, Any]', member_properties))
        member_required = member.get('required')
        if isinstance(member_required, list):
            required.update(cast('list[str]', member_required))
        member_pattern_properties = member.get('patternProperties')
        if isinstance(member_pattern_properties, dict):
            pattern_properties.update(cast('dict[str, Any]', member_pattern_properties))
        if 'additionalProperties' in member:
            additional_values.append(member['additionalProperties'])

    if properties:
        merged['properties'] = properties
    if required:
        merged['required'] = sorted(required)
    if pattern_properties:
        merged['patternProperties'] = pattern_properties
    if additional_values:
        merged['additionalProperties'] = _merge_additional_properties_values(additional_values)

    return merged


def _recurse_children(s: JsonSchema) -> JsonSchema:
    """Flatten `allOf` inside the direct children of `s`, mutating and returning `s`.

    For `type == 'object'` the values of `properties` and `patternProperties` and a dict-valued
    `additionalProperties` are processed; for `type == 'array'` a dict-valued `items` is processed.
    Non-dict child schemas (e.g. the boolean schemas `True` / `False`) are kept as they are.
    """
    schema_type = s.get('type')
    if schema_type == 'object':
        properties = s.get('properties')
        if isinstance(properties, dict):
            s['properties'] = {
                name: _flatten_owned(cast('JsonSchema', child)) if isinstance(child, dict) else child
                for name, child in cast('dict[str, Any]', properties).items()
            }
        additional = s.get('additionalProperties')
        if isinstance(additional, dict):
            s['additionalProperties'] = _flatten_owned(cast('JsonSchema', additional))
        pattern_properties = s.get('patternProperties')
        if isinstance(pattern_properties, dict):
            s['patternProperties'] = {
                pattern: _flatten_owned(cast('JsonSchema', child)) if isinstance(child, dict) else child
                for pattern, child in cast('dict[str, Any]', pattern_properties).items()
            }
    elif schema_type == 'array':
        items = s.get('items')
        if isinstance(items, dict):
            s['items'] = _flatten_owned(cast('JsonSchema', items))
    return s


def _flatten_owned(s: JsonSchema) -> JsonSchema:
    """Flatten `s` and its descendants without copying; `s` must already be a private copy."""
    return _recurse_children(_flatten_current_level(s))


def _recurse_flatten_allof(schema: JsonSchema) -> JsonSchema:
    """Return a flattened deep copy of `schema`; the argument itself is never modified."""
    # Copy once up front; every nested step then works on data owned by this call.
    return _flatten_owned(deepcopy(schema))


def flatten_allof(schema: JsonSchema) -> JsonSchema:
    """Flatten simple object-only allOf combinations by merging object members.

    - Merges properties and unions required lists.
    - Combines additionalProperties conservatively: only False if all are False; otherwise True.
    - Recurses into nested object/array members.
    - Leaves non-object allOfs untouched.

    The input schema is not modified; a new schema is returned.
    """
    return _recurse_flatten_allof(schema)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -143,7 +143,7 @@ class OpenAIJsonSchemaTransformer(JsonSchemaTransformer): | |
| """ | ||
|
|
||
| def __init__(self, schema: JsonSchema, *, strict: bool | None = None): | ||
| super().__init__(schema, strict=strict) | ||
| super().__init__(schema, strict=strict, flatten_allof=True) | ||
| self.root_ref = schema.get('$ref') | ||
|
|
||
| def walk(self) -> JsonSchema: | ||
|
|
@@ -157,7 +157,12 @@ def walk(self) -> JsonSchema: | |
| if self.root_ref is not None: | ||
| result.pop('$ref', None) # We replace references to the self.root_ref with just '#' in the transform method | ||
| root_key = re.sub(r'^#/\$defs/', '', self.root_ref) | ||
| result.update(self.defs.get(root_key) or {}) | ||
| # Use the transformed schema from $defs, not the original self.defs | ||
| if '$defs' in result and root_key in result['$defs']: | ||
| result.update(result['$defs'][root_key]) | ||
| else: | ||
| # Fallback to original if transformed version not available (shouldn't happen in normal flow) | ||
| result.update(self.defs.get(root_key) or {}) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Please check the comment I had on this before; I don't think it's been addressed, or at least I don't understand why we still have these 2 paths instead of standardizing on only using |
||
|
|
||
| return result | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -1915,6 +1915,111 @@ class MyModel(BaseModel): | |||||||||||||||||||
| ) | ||||||||||||||||||||
|
|
||||||||||||||||||||
|
|
||||||||||||||||||||
def test_openai_transformer_with_recursive_ref() -> None:
    """Check that a schema whose root is a `$ref` is resolved to the transformed definition in strict mode."""
    from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer

    # The definition the root `$ref` points at.
    my_model: dict[str, Any] = {
        'type': 'object',
        'properties': {'foo': {'type': 'string'}},
        'required': ['foo'],
    }
    # Root schema that is nothing but a reference, as produced for recursive models.
    schema: dict[str, Any] = {'$ref': '#/$defs/MyModel', '$defs': {'MyModel': my_model}}

    result = OpenAIJsonSchemaTransformer(schema, strict=True).walk()

    # The root must have been filled in from the transformed `$defs` entry.
    assert isinstance(result, dict)
    for expected_key in ('properties', 'required'):
        assert expected_key in result
    # Strict mode forbids extra properties ...
    assert result.get('additionalProperties') is False
    # ... and lists every property as required.
    assert 'foo' in result['required']
|
|
||||||||||||||||||||
|
|
||||||||||||||||||||
| def test_openai_transformer_fallback_when_defs_missing() -> None: | ||||||||||||||||||||
| """Test fallback path when root_key is not in result['$defs'] (line 165). | ||||||||||||||||||||
|
|
||||||||||||||||||||
| This tests the safety net fallback that shouldn't happen in normal flow. | ||||||||||||||||||||
| The fallback uses self.defs (original schema) when the transformed $defs | ||||||||||||||||||||
| doesn't contain the root_key. This edge case is simulated using a mock. | ||||||||||||||||||||
| """ | ||||||||||||||||||||
| from unittest.mock import patch | ||||||||||||||||||||
|
|
||||||||||||||||||||
| from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer | ||||||||||||||||||||
|
|
||||||||||||||||||||
| schema: dict[str, Any] = { | ||||||||||||||||||||
| '$ref': '#/$defs/MyModel', | ||||||||||||||||||||
| '$defs': { | ||||||||||||||||||||
| 'MyModel': { | ||||||||||||||||||||
| 'type': 'object', | ||||||||||||||||||||
| 'properties': {'foo': {'type': 'string'}}, | ||||||||||||||||||||
| 'required': ['foo'], | ||||||||||||||||||||
| }, | ||||||||||||||||||||
| }, | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
| transformer = OpenAIJsonSchemaTransformer(schema, strict=True) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # Simulate edge case: super().walk() returns $defs without root_key | ||||||||||||||||||||
| # This shouldn't happen in normal flow, but we test the fallback path | ||||||||||||||||||||
| with patch.object( | ||||||||||||||||||||
| transformer.__class__.__bases__[0], | ||||||||||||||||||||
|
||||||||||||||||||||
| schema = deepcopy(self.schema) | |
| # First, handle everything but $defs: | |
| schema.pop('$defs', None) | |
| handled = self._handle(schema) | |
| if not self.prefer_inlined_defs and self.defs: | |
| handled['$defs'] = {k: self._handle(v) for k, v in self.defs.items()} | |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's make `flatten_allof` private and not export it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done