Skip to content

Commit 27192da

Browse files
committed
Updated failing tests and flatten logic
1 parent cea2b10 commit 27192da

File tree

3 files changed

+645
-68
lines changed

3 files changed

+645
-68
lines changed

pydantic_ai_slim/pydantic_ai/_json_schema.py

Lines changed: 157 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
JsonSchema = dict[str, Any]
1212

13-
__all__ = ['JsonSchemaTransformer', 'InlineDefsJsonSchemaTransformer', 'flatten_allof']
13+
__all__ = ['JsonSchemaTransformer', 'InlineDefsJsonSchemaTransformer']
1414

1515

1616
@dataclass(init=False)
@@ -81,7 +81,7 @@ def walk(self) -> JsonSchema:
8181
def _handle(self, schema: JsonSchema) -> JsonSchema:
8282
# Flatten allOf if requested, before processing the schema
8383
if self.flatten_allof:
84-
schema = flatten_allof(schema)
84+
schema = _recurse_flatten_allof(schema)
8585

8686
nested_refs = 0
8787
if self.prefer_inlined_defs:
@@ -209,9 +209,123 @@ def _allof_is_object_like(member: JsonSchema) -> bool:
209209

210210

211211
def _merge_additional_properties_values(values: list[Any]) -> bool | JsonSchema:
212+
# If any value is False, return False (most restrictive)
213+
if any(v is False for v in values):
214+
return False
215+
# If any value is a dict schema, we can't easily merge multiple schemas, so return True
212216
if any(isinstance(v, dict) for v in values):
213217
return True
214-
return False if values and all(v is False for v in values) else True
218+
# Default to True (allow additional properties)
219+
return True
220+
221+
222+
def _collect_member_data(
    processed_members: list[JsonSchema],
) -> tuple[
    dict[str, JsonSchema],
    set[str],
    dict[str, JsonSchema],
    list[Any],
    list[set[str]],
    list[dict[str, JsonSchema]],
    list[Any],
]:
    """Gather the object-level keywords of every member schema in a single pass.

    Returns a 7-tuple, in this order:

    1. the union of all members' `properties` (later members overwrite earlier ones on name clashes),
    2. the union of all members' `required` names,
    3. the union of all members' `patternProperties` (later members win),
    4. every `additionalProperties` value that was explicitly present, in member order,
    5. for each member with `additionalProperties: False`, the set of property names it declares,
    6. each member's own `properties` dict (an empty dict when absent or not a dict), in member order,
    7. each member's `additionalProperties` value (`None` when absent), in member order.

    Items 6 and 7 always have one entry per member, so they can be zipped together.
    """
    merged_props: dict[str, JsonSchema] = {}
    required_names: set[str] = set()
    merged_patterns: dict[str, JsonSchema] = {}
    declared_additional: list[Any] = []
    closed_member_names: list[set[str]] = []
    per_member_props: list[dict[str, JsonSchema]] = []
    per_member_additional: list[Any] = []

    for member in processed_members:
        # Normalise `properties`: anything that is not a dict counts as "no properties".
        raw_props = member.get('properties')
        own_props: dict[str, JsonSchema] = {}
        if isinstance(raw_props, dict):
            own_props = cast(dict[str, JsonSchema], raw_props)

        per_member_props.append(own_props)
        per_member_additional.append(member.get('additionalProperties'))
        merged_props.update(own_props)

        member_required = member.get('required')
        if isinstance(member_required, list):
            required_names.update(member_required)

        member_patterns = member.get('patternProperties')
        if isinstance(member_patterns, dict):
            merged_patterns.update(member_patterns)

        # Only explicitly declared values are recorded here; an absent key adds nothing.
        if 'additionalProperties' in member:
            additional = member['additionalProperties']
            declared_additional.append(additional)
            if additional is False:
                # A closed member only admits the names it declares itself.
                closed_member_names.append(set(own_props))

    return (
        merged_props,
        required_names,
        merged_patterns,
        declared_additional,
        closed_member_names,
        per_member_props,
        per_member_additional,
    )
270+
271+
272+
def _filter_by_restricted_property_sets(merged: JsonSchema, restricted_property_sets: list[set[str]]) -> None:
    """Drop merged properties that are not declared by every restricted member.

    `restricted_property_sets` holds one set of property names per restricted member.
    Only names present in all of those sets survive in `merged['properties']` and
    `merged['required']`. A key whose filtered value ends up empty is removed from
    `merged` entirely. `merged` is modified in place; nothing happens when
    `restricted_property_sets` is empty.
    """
    if not restricted_property_sets:
        return

    # Names that every restricted member declares.
    permitted = set.intersection(*restricted_property_sets)

    if 'properties' in merged:
        kept_props = {name: sub for name, sub in merged['properties'].items() if name in permitted}
        if kept_props:
            merged['properties'] = kept_props
        else:
            del merged['properties']

    if 'required' in merged:
        kept_required = [name for name in merged['required'] if name in permitted]
        if kept_required:
            merged['required'] = kept_required
        else:
            del merged['required']
289+
290+
291+
def _filter_incompatible_properties(
    merged: JsonSchema,
    members_properties: list[dict[str, JsonSchema]],
    members_additional_props: list[Any],
) -> None:
    """Remove merged properties that at least one member would not accept.

    `members_properties` and `members_additional_props` are walked in lockstep, one
    entry per member. A merged property is dropped when any member either:

    - declares the same property with a `type` set disjoint from the merged one,
    - does not declare it and has `additionalProperties` set to `False`, or
    - does not declare it and has a schema-valued `additionalProperties` whose
      `type` set does not cover the merged property's `type` set.

    Type comparisons are skipped whenever either side has no usable `type`.
    Dropped names are also removed from `merged['required']`, and keys left empty
    are deleted. `merged` is modified in place.
    """
    if 'properties' not in merged:
        return

    def _rejected_by(
        name: str,
        types: set[str] | None,
        member_props: dict[str, JsonSchema],
        member_additional: Any,
    ) -> bool:
        # The member declares the property itself: only a type clash rejects it.
        if name in member_props:
            declared = _get_type_set(member_props[name])
            return bool(types and declared and types.isdisjoint(declared))
        # Undeclared property on a closed member.
        if member_additional is False:
            return True
        # Undeclared property that must satisfy the member's additionalProperties schema.
        if isinstance(member_additional, dict):
            allowed = _get_type_set(cast(JsonSchema, member_additional))
            return bool(types and allowed and not types.issubset(allowed))
        return False

    rejected: set[str] = set()
    for name, sub_schema in merged['properties'].items():
        types = _get_type_set(sub_schema)
        # `any` stops at the first rejecting member.
        if any(
            _rejected_by(name, types, member_props, member_additional)
            for member_props, member_additional in zip(members_properties, members_additional_props)
        ):
            rejected.add(name)

    if not rejected:
        return

    survivors = {name: sub for name, sub in merged['properties'].items() if name not in rejected}
    if survivors:
        merged['properties'] = survivors
    else:
        del merged['properties']

    if 'required' in merged:
        still_required = [name for name in merged['required'] if name not in rejected]
        if still_required:
            merged['required'] = still_required
        else:
            del merged['required']
215329

216330

217331
def _flatten_current_level(s: JsonSchema) -> JsonSchema:
@@ -230,37 +344,66 @@ def _flatten_current_level(s: JsonSchema) -> JsonSchema:
230344
merged: JsonSchema = {k: v for k, v in s.items() if k != 'allOf'}
231345
merged['type'] = 'object'
232346

347+
# Collect initial properties from merged schema
233348
properties: dict[str, JsonSchema] = {}
234349
if isinstance(merged.get('properties'), dict):
235350
properties.update(merged['properties'])
236351

237352
required: set[str] = set(merged.get('required', []) or [])
238353
pattern_properties: dict[str, JsonSchema] = dict(merged.get('patternProperties', {}) or {})
239-
additional_values: list[Any] = []
240-
241-
for m in processed_members:
242-
if isinstance(m.get('properties'), dict):
243-
properties.update(m['properties'])
244-
if isinstance(m.get('required'), list):
245-
required.update(m['required'])
246-
if isinstance(m.get('patternProperties'), dict):
247-
pattern_properties.update(m['patternProperties'])
248-
if 'additionalProperties' in m:
249-
additional_values.append(m['additionalProperties'])
250354

355+
# Collect data from all members
356+
(
357+
member_properties,
358+
member_required,
359+
member_pattern_properties,
360+
additional_values,
361+
restricted_property_sets,
362+
members_properties,
363+
members_additional_props,
364+
) = _collect_member_data(processed_members)
365+
366+
# Merge all collected data
367+
properties.update(member_properties)
368+
required.update(member_required)
369+
pattern_properties.update(member_pattern_properties)
370+
371+
# Apply merged properties, required, and patternProperties
251372
if properties:
252373
merged['properties'] = {k: _recurse_flatten_allof(v) for k, v in properties.items()}
253374
if required:
254375
merged['required'] = sorted(required)
255376
if pattern_properties:
256377
merged['patternProperties'] = {k: _recurse_flatten_allof(v) for k, v in pattern_properties.items()}
257378

379+
# Filter by restricted property sets (additionalProperties: False)
380+
_filter_by_restricted_property_sets(merged, restricted_property_sets)
381+
382+
# Merge additionalProperties
258383
if additional_values:
259384
merged['additionalProperties'] = _merge_additional_properties_values(additional_values)
260385

386+
# Filter incompatible properties based on additionalProperties constraints
387+
_filter_incompatible_properties(merged, members_properties, members_additional_props)
388+
261389
return merged
262390

263391

392+
def _get_type_set(schema: JsonSchema | None) -> set[str] | None:
    """Return the `type` keyword of `schema` as a set of strings.

    Returns `None` when `schema` is `None` or empty, or when its `type` is neither
    a string nor a list. A string `type` yields a one-element set. A list `type`
    yields the set of `str()` of each entry, so an empty list yields an empty set.
    """
    if not schema:
        return None

    declared = schema.get('type')
    if isinstance(declared, str):
        return {declared}
    if isinstance(declared, list):
        return {str(entry) for entry in cast(list[Any], declared)}
    return None
405+
406+
264407
def _recurse_children(s: JsonSchema) -> JsonSchema:
265408
t = s.get('type')
266409
if t == 'object':
@@ -292,14 +435,3 @@ def _recurse_flatten_allof(schema: JsonSchema) -> JsonSchema:
292435
s = _flatten_current_level(s)
293436
s = _recurse_children(s)
294437
return s
295-
296-
297-
def flatten_allof(schema: JsonSchema) -> JsonSchema:
298-
"""Flatten simple object-only allOf combinations by merging object members.
299-
300-
- Merges properties and unions required lists.
301-
- Combines additionalProperties conservatively: only False if all are False; otherwise True.
302-
- Recurses into nested object/array members.
303-
- Leaves non-object allOfs untouched.
304-
"""
305-
return _recurse_flatten_allof(schema)

tests/models/test_openai.py

Lines changed: 63 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1936,27 +1936,69 @@ def test_openai_transformer_with_recursive_ref() -> None:
19361936

19371937
# The transformer should resolve the $ref and use the transformed schema from $defs
19381938
# (not the original self.defs, which was the bug we fixed)
1939-
assert isinstance(result, dict)
1940-
# In strict mode, all properties should be required
1941-
assert 'properties' in result
1942-
assert 'required' in result
1943-
# The transformed schema should have strict mode applied (additionalProperties: False)
1944-
assert result.get('additionalProperties') is False
1945-
# All properties should be in required list (strict mode requirement)
1946-
assert 'foo' in result['required']
1939+
assert result == snapshot(
1940+
{
1941+
'$defs': {
1942+
'MyModel': {
1943+
'type': 'object',
1944+
'properties': {'foo': {'type': 'string'}},
1945+
'required': ['foo'],
1946+
'additionalProperties': False,
1947+
}
1948+
},
1949+
'type': 'object',
1950+
'properties': {'foo': {'type': 'string'}},
1951+
'required': ['foo'],
1952+
'additionalProperties': False,
1953+
}
1954+
)
19471955

19481956

19491957
def test_openai_transformer_fallback_when_defs_missing() -> None:
19501958
"""Test fallback path when root_key is not in result['$defs'] (line 165).
19511959
19521960
This tests the safety net fallback that shouldn't happen in normal flow.
19531961
The fallback uses self.defs (original schema) when the transformed $defs
1954-
doesn't contain the root_key. This edge case is simulated using a mock.
1955-
"""
1956-
from unittest.mock import patch
1962+
doesn't contain the root_key.
19571963
1964+
We test this by creating a custom transformer subclass that overrides the base
1965+
class walk() to return a result without the root_key, allowing us to test the
1966+
actual fallback code path in OpenAIJsonSchemaTransformer.walk().
1967+
"""
1968+
from pydantic_ai._json_schema import JsonSchema, JsonSchemaTransformer
19581969
from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer
19591970

1971+
# Create a custom transformer that overrides walk() to intercept super().walk()
1972+
# and manipulate the result, then call the actual OpenAIJsonSchemaTransformer.walk()
1973+
# logic which will hit the fallback path
1974+
class TestTransformer(OpenAIJsonSchemaTransformer):
1975+
def __init__(self, schema: dict[str, Any], *, strict: bool | None = None):
1976+
JsonSchemaTransformer.__init__(self, schema, strict=strict, flatten_allof=True)
1977+
self.root_ref = schema.get('$ref')
1978+
1979+
def walk(self) -> JsonSchema:
1980+
# Store the original base class walk method
1981+
base_walk = JsonSchemaTransformer.walk
1982+
1983+
# Create a wrapper that manipulates the result to simulate the edge case
1984+
def wrapped_walk(self_instance: TestTransformer) -> JsonSchema:
1985+
result = base_walk(self_instance)
1986+
# Remove root_key from $defs to simulate the edge case
1987+
if '$defs' in result and 'MyModel' in result.get('$defs', {}):
1988+
result['$defs'].pop('MyModel')
1989+
return result
1990+
1991+
# Temporarily replace walk() on the JsonSchemaTransformer class itself (this affects every instance until it is restored)
1992+
JsonSchemaTransformer.walk = wrapped_walk # type: ignore[assignment]
1993+
try:
1994+
# Now call the actual OpenAIJsonSchemaTransformer.walk() method
1995+
# which will call super().walk() (our wrapped version) and then
1996+
# hit the fallback path at lines 164-165
1997+
return super().walk()
1998+
finally:
1999+
# Restore the original method
2000+
JsonSchemaTransformer.walk = base_walk
2001+
19602002
schema: dict[str, Any] = {
19612003
'$ref': '#/$defs/MyModel',
19622004
'$defs': {
@@ -1968,19 +2010,17 @@ def test_openai_transformer_fallback_when_defs_missing() -> None:
19682010
},
19692011
}
19702012

1971-
transformer = OpenAIJsonSchemaTransformer(schema, strict=True)
2013+
transformer = TestTransformer(schema, strict=True)
2014+
# Call walk() which will execute the actual OpenAIJsonSchemaTransformer.walk() logic
2015+
# and hit the fallback path at lines 164-165
2016+
result = transformer.walk()
19722017

1973-
# Simulate edge case: super().walk() returns $defs without root_key
1974-
# This shouldn't happen in normal flow, but we test the fallback path
1975-
with patch.object(
1976-
transformer.__class__.__bases__[0],
1977-
'walk',
1978-
return_value={'$defs': {'OtherModel': {'type': 'object'}}},
1979-
):
1980-
result = transformer.walk()
1981-
# Fallback should use self.defs.get(root_key) which contains MyModel
1982-
assert isinstance(result, dict)
1983-
assert 'properties' in result or 'type' in result
2018+
# Verify the fallback worked: result should have MyModel's properties from self.defs
2019+
# Note: The fallback uses the original, untransformed schema, so it won't have
2020+
# additionalProperties: False applied (that transformation happens in transform())
2021+
assert result == snapshot(
2022+
{'$defs': {}, 'type': 'object', 'properties': {'foo': {'type': 'string'}}, 'required': ['foo']}
2023+
)
19842024

19852025

19862026
def test_openai_transformer_flattens_allof() -> None:

0 commit comments

Comments
 (0)