-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Flatten allOf properties for OpenAI compatibility #3451
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
afddade
8ff4db7
e4aebaf
e5815fd
b1058e9
1244bfd
f143029
eaa685e
cea2b10
27192da
756bdc4
6b47179
6114437
16e33e8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -4,12 +4,14 @@ | |||||||||||||||||
| from abc import ABC, abstractmethod | ||||||||||||||||||
| from copy import deepcopy | ||||||||||||||||||
| from dataclasses import dataclass | ||||||||||||||||||
| from typing import Any, Literal | ||||||||||||||||||
| from typing import Any, Literal, cast | ||||||||||||||||||
|
|
||||||||||||||||||
| from .exceptions import UserError | ||||||||||||||||||
|
|
||||||||||||||||||
| JsonSchema = dict[str, Any] | ||||||||||||||||||
|
|
||||||||||||||||||
| __all__ = ['JsonSchemaTransformer', 'InlineDefsJsonSchemaTransformer'] | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| @dataclass(init=False) | ||||||||||||||||||
| class JsonSchemaTransformer(ABC): | ||||||||||||||||||
|
|
@@ -26,14 +28,18 @@ def __init__( | |||||||||||||||||
| strict: bool | None = None, | ||||||||||||||||||
| prefer_inlined_defs: bool = False, | ||||||||||||||||||
| simplify_nullable_unions: bool = False, # TODO (v2): Remove this, no longer used | ||||||||||||||||||
| flatten_allof: bool = False, | ||||||||||||||||||
| ): | ||||||||||||||||||
| self.schema = schema | ||||||||||||||||||
|
|
||||||||||||||||||
| self.strict = strict | ||||||||||||||||||
| self.is_strict_compatible = True # Can be set to False by subclasses to set `strict` on `ToolDefinition` when set not set by user explicitly | ||||||||||||||||||
| # Can be set to False by subclasses to set `strict` on `ToolDefinition` | ||||||||||||||||||
| # when not set explicitly by the user. | ||||||||||||||||||
| self.is_strict_compatible = True | ||||||||||||||||||
|
|
||||||||||||||||||
| self.prefer_inlined_defs = prefer_inlined_defs | ||||||||||||||||||
| self.simplify_nullable_unions = simplify_nullable_unions | ||||||||||||||||||
| self.flatten_allof = flatten_allof | ||||||||||||||||||
|
|
||||||||||||||||||
| self.defs: dict[str, JsonSchema] = self.schema.get('$defs', {}) | ||||||||||||||||||
| self.refs_stack: list[str] = [] | ||||||||||||||||||
|
|
@@ -73,6 +79,10 @@ def walk(self) -> JsonSchema: | |||||||||||||||||
| return handled | ||||||||||||||||||
|
|
||||||||||||||||||
| def _handle(self, schema: JsonSchema) -> JsonSchema: | ||||||||||||||||||
| # Flatten allOf if requested, before processing the schema | ||||||||||||||||||
| if self.flatten_allof: | ||||||||||||||||||
| schema = _recurse_flatten_allof(schema) | ||||||||||||||||||
|
|
||||||||||||||||||
| nested_refs = 0 | ||||||||||||||||||
| if self.prefer_inlined_defs: | ||||||||||||||||||
| while ref := schema.get('$ref'): | ||||||||||||||||||
|
|
@@ -187,3 +197,274 @@ def __init__(self, schema: JsonSchema, *, strict: bool | None = None): | |||||||||||||||||
|
|
||||||||||||||||||
| def transform(self, schema: JsonSchema) -> JsonSchema: | ||||||||||||||||||
| return schema | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| def _get_type_set(schema: JsonSchema) -> set[str] | None: | ||||||||||||||||||
| """Extract type(s) from a schema as a set of strings.""" | ||||||||||||||||||
| schema_type = schema.get('type') | ||||||||||||||||||
| if isinstance(schema_type, list): | ||||||||||||||||||
| return {str(t) for t in cast(list[Any], schema_type)} | ||||||||||||||||||
| if isinstance(schema_type, str): | ||||||||||||||||||
| return {schema_type} | ||||||||||||||||||
| return None | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| def _process_nested_schemas_without_allof(s: JsonSchema) -> JsonSchema: | ||||||||||||||||||
| """Process nested schemas recursively when there is no allOf at the current level.""" | ||||||||||||||||||
| schema_type = s.get('type') | ||||||||||||||||||
| if schema_type == 'object': | ||||||||||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We already have pretty similar code here: pydantic-ai/pydantic_ai_slim/pydantic_ai/_json_schema.py Lines 94 to 101 in 1b576dd
Would it be an option to merge this new processing logic into that one, somehow? Or at least make them less seemingly-duplicative? I just don't love the idea of recursively walking the JSON schema twice with different implementations. |
||||||||||||||||||
| if isinstance(s.get('properties'), dict): | ||||||||||||||||||
| s['properties'] = { | ||||||||||||||||||
| k: _recurse_flatten_allof(cast(JsonSchema, v)) | ||||||||||||||||||
| for k, v in s['properties'].items() | ||||||||||||||||||
| if isinstance(v, dict) | ||||||||||||||||||
| } | ||||||||||||||||||
| if isinstance(s.get('additionalProperties'), dict): | ||||||||||||||||||
| s['additionalProperties'] = _recurse_flatten_allof(cast(JsonSchema, s['additionalProperties'])) | ||||||||||||||||||
| if isinstance(s.get('patternProperties'), dict): | ||||||||||||||||||
| s['patternProperties'] = { | ||||||||||||||||||
| k: _recurse_flatten_allof(cast(JsonSchema, v)) | ||||||||||||||||||
| for k, v in s['patternProperties'].items() | ||||||||||||||||||
| if isinstance(v, dict) | ||||||||||||||||||
| } | ||||||||||||||||||
| elif schema_type == 'array': | ||||||||||||||||||
| if isinstance(s.get('items'), dict): | ||||||||||||||||||
| s['items'] = _recurse_flatten_allof(cast(JsonSchema, s['items'])) | ||||||||||||||||||
| return s | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| def _collect_base_schema_data( | ||||||||||||||||||
| result: JsonSchema, | ||||||||||||||||||
| ) -> tuple[dict[str, JsonSchema], set[str], dict[str, JsonSchema], list[Any], list[set[str]]]: | ||||||||||||||||||
| """Collect data from base schema: properties, required, patternProperties, additionalProperties.""" | ||||||||||||||||||
| properties: dict[str, JsonSchema] = {} | ||||||||||||||||||
| required: set[str] = set() | ||||||||||||||||||
| pattern_properties: dict[str, JsonSchema] = {} | ||||||||||||||||||
| additional_values: list[Any] = [] | ||||||||||||||||||
| restricted_property_sets: list[set[str]] = [] | ||||||||||||||||||
|
|
||||||||||||||||||
| base_properties = ( | ||||||||||||||||||
| cast(dict[str, JsonSchema], result.get('properties', {})) if isinstance(result.get('properties'), dict) else {} | ||||||||||||||||||
| ) | ||||||||||||||||||
| base_additional = result.get('additionalProperties') | ||||||||||||||||||
|
|
||||||||||||||||||
| if base_properties: | ||||||||||||||||||
| properties.update(base_properties) | ||||||||||||||||||
| if isinstance(result.get('required'), list): | ||||||||||||||||||
| required.update(result['required']) | ||||||||||||||||||
| if isinstance(result.get('patternProperties'), dict): | ||||||||||||||||||
| pattern_properties.update(result['patternProperties']) | ||||||||||||||||||
| if base_additional is False: | ||||||||||||||||||
| additional_values.append(False) | ||||||||||||||||||
| # Only restrict if base schema has properties; if base has no properties but additionalProperties: False, | ||||||||||||||||||
| # it means no additional properties are allowed, but properties from allOf members are still valid | ||||||||||||||||||
| if base_properties: | ||||||||||||||||||
| restricted_property_sets.append(set(base_properties.keys())) | ||||||||||||||||||
|
|
||||||||||||||||||
| return properties, required, pattern_properties, additional_values, restricted_property_sets | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| def _collect_member_data( | ||||||||||||||||||
| processed_members: list[JsonSchema], | ||||||||||||||||||
| properties: dict[str, JsonSchema], | ||||||||||||||||||
| required: set[str], | ||||||||||||||||||
| pattern_properties: dict[str, JsonSchema], | ||||||||||||||||||
| additional_values: list[Any], | ||||||||||||||||||
| restricted_property_sets: list[set[str]], | ||||||||||||||||||
| members_properties: list[dict[str, JsonSchema]], | ||||||||||||||||||
| members_additional_props: list[Any], | ||||||||||||||||||
| ) -> None: | ||||||||||||||||||
| """Collect data from allOf members and update the collections.""" | ||||||||||||||||||
| for m in processed_members: | ||||||||||||||||||
| member_props = ( | ||||||||||||||||||
| cast(dict[str, JsonSchema], m.get('properties', {})) if isinstance(m.get('properties'), dict) else {} | ||||||||||||||||||
| ) | ||||||||||||||||||
| members_properties.append(member_props) | ||||||||||||||||||
| members_additional_props.append(m.get('additionalProperties')) | ||||||||||||||||||
|
|
||||||||||||||||||
| if member_props: | ||||||||||||||||||
| properties.update(member_props) | ||||||||||||||||||
| if isinstance(m.get('required'), list): | ||||||||||||||||||
| required.update(m['required']) | ||||||||||||||||||
| if isinstance(m.get('patternProperties'), dict): | ||||||||||||||||||
| pattern_properties.update(m['patternProperties']) | ||||||||||||||||||
| if 'additionalProperties' in m: | ||||||||||||||||||
| additional_values.append(m['additionalProperties']) | ||||||||||||||||||
| if m['additionalProperties'] is False: | ||||||||||||||||||
| restricted_property_sets.append(set(member_props.keys())) | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| def _filter_by_restricted_property_sets( | ||||||||||||||||||
| properties: dict[str, JsonSchema], required: set[str], restricted_property_sets: list[set[str]] | ||||||||||||||||||
| ) -> tuple[dict[str, JsonSchema], set[str]]: | ||||||||||||||||||
| """Filter properties and required by restricted property sets (intersection when some/all have additionalProperties: False).""" | ||||||||||||||||||
| if not restricted_property_sets: | ||||||||||||||||||
| return properties, required | ||||||||||||||||||
|
|
||||||||||||||||||
| # Intersection of allowed properties from all members with additionalProperties: False | ||||||||||||||||||
| allowed_names = restricted_property_sets[0].copy() | ||||||||||||||||||
| for prop_set in restricted_property_sets[1:]: | ||||||||||||||||||
| allowed_names &= prop_set | ||||||||||||||||||
| # Filter properties to only include allowed names | ||||||||||||||||||
| if allowed_names: | ||||||||||||||||||
| properties = {k: v for k, v in properties.items() if k in allowed_names} | ||||||||||||||||||
| required = {r for r in required if r in allowed_names} | ||||||||||||||||||
| else: | ||||||||||||||||||
| # Empty intersection - remove all properties | ||||||||||||||||||
| properties = {} | ||||||||||||||||||
| required = set() | ||||||||||||||||||
|
|
||||||||||||||||||
| return properties, required | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| def _filter_incompatible_properties( | ||||||||||||||||||
| properties: dict[str, JsonSchema], | ||||||||||||||||||
| required: set[str], | ||||||||||||||||||
| members_properties: list[dict[str, JsonSchema]], | ||||||||||||||||||
| members_additional_props: list[Any], | ||||||||||||||||||
| ) -> tuple[dict[str, JsonSchema], set[str]]: | ||||||||||||||||||
| """Filter incompatible properties based on additionalProperties constraints.""" | ||||||||||||||||||
| if not properties: | ||||||||||||||||||
| return properties, required | ||||||||||||||||||
|
|
||||||||||||||||||
| incompatible_props: set[str] = set() | ||||||||||||||||||
|
|
||||||||||||||||||
| for prop_name, prop_schema in properties.items(): | ||||||||||||||||||
| prop_types = _get_type_set(prop_schema) | ||||||||||||||||||
|
|
||||||||||||||||||
| # Check compatibility with each member (including base) | ||||||||||||||||||
| for member_props, member_additional in zip(members_properties, members_additional_props): | ||||||||||||||||||
| if prop_name in member_props: | ||||||||||||||||||
| # Property explicitly defined - check type compatibility | ||||||||||||||||||
| member_prop_types = _get_type_set(member_props[prop_name]) | ||||||||||||||||||
| if prop_types and member_prop_types and not prop_types & member_prop_types: | ||||||||||||||||||
| incompatible_props.add(prop_name) | ||||||||||||||||||
| break | ||||||||||||||||||
| continue # Compatible, check next member | ||||||||||||||||||
| if isinstance(member_additional, dict): | ||||||||||||||||||
| allowed_types = _get_type_set(cast(JsonSchema, member_additional)) | ||||||||||||||||||
| # Property type must be a subset of allowed types | ||||||||||||||||||
| if prop_types and allowed_types and not (prop_types <= allowed_types): | ||||||||||||||||||
| incompatible_props.add(prop_name) | ||||||||||||||||||
| break | ||||||||||||||||||
|
|
||||||||||||||||||
| if incompatible_props: | ||||||||||||||||||
| allowed_names = {k for k in properties.keys() if k not in incompatible_props} | ||||||||||||||||||
| properties = {k: v for k, v in properties.items() if k in allowed_names} | ||||||||||||||||||
| required = {r for r in required if r in allowed_names} | ||||||||||||||||||
|
|
||||||||||||||||||
| return properties, required | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| def _process_result_nested_schemas(result: JsonSchema) -> None: | ||||||||||||||||||
| """Recursively process nested schemas in the result (additionalProperties, patternProperties, items).""" | ||||||||||||||||||
| if isinstance(result.get('additionalProperties'), dict): | ||||||||||||||||||
| result['additionalProperties'] = _recurse_flatten_allof(cast(JsonSchema, result['additionalProperties'])) | ||||||||||||||||||
| if isinstance(result.get('patternProperties'), dict): | ||||||||||||||||||
| result['patternProperties'] = { | ||||||||||||||||||
| k: _recurse_flatten_allof(cast(JsonSchema, v)) | ||||||||||||||||||
| for k, v in result['patternProperties'].items() | ||||||||||||||||||
| if isinstance(v, dict) | ||||||||||||||||||
| } | ||||||||||||||||||
| if isinstance(result.get('items'), dict): | ||||||||||||||||||
| result['items'] = _recurse_flatten_allof(cast(JsonSchema, result['items'])) | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| def _recurse_flatten_allof(schema: JsonSchema) -> JsonSchema: | ||||||||||||||||||
| """Recursively flatten allOf in a JSON schema. | ||||||||||||||||||
|
|
||||||||||||||||||
| This function: | ||||||||||||||||||
| 1. Makes a deep copy of the schema | ||||||||||||||||||
| 2. Flattens allOf at the current level | ||||||||||||||||||
| 3. Recursively processes nested schemas (properties, items, etc.) | ||||||||||||||||||
| """ | ||||||||||||||||||
| s = deepcopy(schema) | ||||||||||||||||||
|
|
||||||||||||||||||
| # Case 1: No allOf - process nested schemas recursively and return | ||||||||||||||||||
| allof = s.get('allOf') | ||||||||||||||||||
| if not isinstance(allof, list) or not allof: | ||||||||||||||||||
| return _process_nested_schemas_without_allof(s) | ||||||||||||||||||
|
|
||||||||||||||||||
| # Check all members are dicts | ||||||||||||||||||
| members = cast(list[JsonSchema], allof) | ||||||||||||||||||
| if not all(isinstance(m, dict) for m in members): | ||||||||||||||||||
| return s | ||||||||||||||||||
|
|
||||||||||||||||||
| # Check all members are object-like (can be merged) | ||||||||||||||||||
| def _is_object_like(member: JsonSchema) -> bool: | ||||||||||||||||||
| member_type = member.get('type') | ||||||||||||||||||
| if member_type is None: | ||||||||||||||||||
| # No type but has object-like keys | ||||||||||||||||||
| keys = ('properties', 'additionalProperties', 'patternProperties') | ||||||||||||||||||
| return bool(any(k in member for k in keys)) | ||||||||||||||||||
| return isinstance(member_type, str) and member_type == 'object' | ||||||||||||||||||
|
|
||||||||||||||||||
| if not all(_is_object_like(m) for m in members): | ||||||||||||||||||
| return s | ||||||||||||||||||
|
|
||||||||||||||||||
| # Recursively flatten each member first | ||||||||||||||||||
| processed_members = [_recurse_flatten_allof(m) for m in members] | ||||||||||||||||||
| result: JsonSchema = {k: v for k, v in s.items() if k != 'allOf'} | ||||||||||||||||||
| result['type'] = 'object' | ||||||||||||||||||
|
|
||||||||||||||||||
| # Collect data from base schema and members | ||||||||||||||||||
| base_properties = ( | ||||||||||||||||||
| cast(dict[str, JsonSchema], result.get('properties', {})) if isinstance(result.get('properties'), dict) else {} | ||||||||||||||||||
| ) | ||||||||||||||||||
| base_additional = result.get('additionalProperties') | ||||||||||||||||||
|
|
||||||||||||||||||
| properties, required, pattern_properties, additional_values, restricted_property_sets = _collect_base_schema_data( | ||||||||||||||||||
| result | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
||||||||||||||||||
| # Then merge properties from all members | ||||||||||||||||||
| members_properties: list[dict[str, JsonSchema]] = [base_properties] | ||||||||||||||||||
| members_additional_props: list[Any] = [base_additional] | ||||||||||||||||||
|
|
||||||||||||||||||
| _collect_member_data( | ||||||||||||||||||
| processed_members, | ||||||||||||||||||
| properties, | ||||||||||||||||||
| required, | ||||||||||||||||||
| pattern_properties, | ||||||||||||||||||
| additional_values, | ||||||||||||||||||
| restricted_property_sets, | ||||||||||||||||||
| members_properties, | ||||||||||||||||||
| members_additional_props, | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
||||||||||||||||||
| # Filter by restricted property sets and incompatible properties | ||||||||||||||||||
| properties, required = _filter_by_restricted_property_sets(properties, required, restricted_property_sets) | ||||||||||||||||||
| properties, required = _filter_incompatible_properties( | ||||||||||||||||||
| properties, required, members_properties, members_additional_props | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
||||||||||||||||||
| # Apply filtered properties | ||||||||||||||||||
| if properties: | ||||||||||||||||||
| # Recursively flatten nested properties | ||||||||||||||||||
| result['properties'] = {k: _recurse_flatten_allof(v) for k, v in properties.items()} | ||||||||||||||||||
| if required: | ||||||||||||||||||
| result['required'] = sorted(required) | ||||||||||||||||||
| if pattern_properties: | ||||||||||||||||||
| result['patternProperties'] = {k: _recurse_flatten_allof(v) for k, v in pattern_properties.items()} | ||||||||||||||||||
|
|
||||||||||||||||||
| # Merge additionalProperties | ||||||||||||||||||
| if additional_values: | ||||||||||||||||||
| # If any is False, result is False (most restrictive) | ||||||||||||||||||
| if any(v is False for v in additional_values): | ||||||||||||||||||
| result['additionalProperties'] = False | ||||||||||||||||||
| # If there's exactly one dict schema, preserve it | ||||||||||||||||||
| elif len(additional_values) == 1 and isinstance(additional_values[0], dict): | ||||||||||||||||||
| result['additionalProperties'] = additional_values[0] | ||||||||||||||||||
| # If any is a dict schema (multiple), result is True (can't merge multiple schemas) | ||||||||||||||||||
| elif any(isinstance(v, dict) for v in additional_values): | ||||||||||||||||||
| result['additionalProperties'] = True | ||||||||||||||||||
| # Otherwise, default to True | ||||||||||||||||||
| else: | ||||||||||||||||||
| result['additionalProperties'] = True | ||||||||||||||||||
|
|
||||||||||||||||||
| # Recursively process nested schemas (additionalProperties, patternProperties) | ||||||||||||||||||
| # Note: items is only valid for array types, not object types, so result.get('items') should never | ||||||||||||||||||
| # be present when result['type'] == 'object'. However, we keep this check for robustness. | ||||||||||||||||||
| _process_result_nested_schemas(result) | ||||||||||||||||||
|
|
||||||||||||||||||
| return result | ||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -143,7 +143,7 @@ class OpenAIJsonSchemaTransformer(JsonSchemaTransformer): | |
| """ | ||
|
|
||
| def __init__(self, schema: JsonSchema, *, strict: bool | None = None): | ||
| super().__init__(schema, strict=strict) | ||
| super().__init__(schema, strict=strict, flatten_allof=True) | ||
| self.root_ref = schema.get('$ref') | ||
|
|
||
| def walk(self) -> JsonSchema: | ||
|
|
@@ -157,7 +157,12 @@ def walk(self) -> JsonSchema: | |
| if self.root_ref is not None: | ||
| result.pop('$ref', None) # We replace references to the self.root_ref with just '#' in the transform method | ||
| root_key = re.sub(r'^#/\$defs/', '', self.root_ref) | ||
| result.update(self.defs.get(root_key) or {}) | ||
| # Use the transformed schema from $defs, not the original self.defs | ||
| if '$defs' in result and root_key in result['$defs']: | ||
| result.update(result['$defs'][root_key]) | ||
| else: | ||
| # Fallback to original if transformed version not available (shouldn't happen in normal flow) | ||
| result.update(self.defs.get(root_key) or {}) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please check the comment I had on this before; I don't think it's been addressed or at least I don't understand why we still have these 2 paths instead of standardizing in only using |
||
|
|
||
| return result | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
With this change, it's now ambiguous which line the comment belongs to :) Can you move it back please, even though the line is uncomfortable long? Or add a blank line between the 2 vars?