diff --git a/python/pydantic_core/__init__.py b/python/pydantic_core/__init__.py index 791de9d92..27dce704c 100644 --- a/python/pydantic_core/__init__.py +++ b/python/pydantic_core/__init__.py @@ -5,6 +5,7 @@ from ._pydantic_core import ( ArgsKwargs, + GatherInvalidDefinitionError, MultiHostUrl, PydanticCustomError, PydanticKnownError, @@ -23,11 +24,12 @@ ValidationError, __version__, from_json, + gather_schemas_for_cleaning, to_json, to_jsonable_python, validate_core_schema, ) -from .core_schema import CoreConfig, CoreSchema, CoreSchemaType, ErrorType +from .core_schema import CoreConfig, CoreSchema, CoreSchemaType, DefinitionReferenceSchema, ErrorType if _sys.version_info < (3, 11): from typing_extensions import NotRequired as _NotRequired @@ -62,11 +64,13 @@ 'PydanticUseDefault', 'PydanticSerializationError', 'PydanticSerializationUnexpectedValue', + 'GatherInvalidDefinitionError', 'TzInfo', 'to_json', 'from_json', 'to_jsonable_python', 'validate_core_schema', + 'gather_schemas_for_cleaning', ] @@ -137,3 +141,11 @@ class MultiHostHost(_TypedDict): """The host part of this host, or `None`.""" port: int | None """The port part of this host, or `None`.""" + + +class GatherResult(_TypedDict): + """Internal result of gathering schemas for cleaning.""" + + inlinable_def_refs: dict[str, DefinitionReferenceSchema | None] + recursive_refs: set[str] + schemas_with_meta_keys: dict[str, list[CoreSchema]] | None diff --git a/python/pydantic_core/_pydantic_core.pyi b/python/pydantic_core/_pydantic_core.pyi index f3103f28f..9155fc99a 100644 --- a/python/pydantic_core/_pydantic_core.pyi +++ b/python/pydantic_core/_pydantic_core.pyi @@ -5,7 +5,7 @@ from typing import Any, Callable, Generic, Literal, TypeVar, final from _typeshed import SupportsAllComparisons from typing_extensions import LiteralString, Self, TypeAlias -from pydantic_core import ErrorDetails, ErrorTypeInfo, InitErrorDetails, MultiHostHost +from pydantic_core import ErrorDetails, ErrorTypeInfo, GatherResult, InitErrorDetails, MultiHostHost from pydantic_core.core_schema import CoreConfig, CoreSchema, ErrorType __all__ = [ @@ -28,6 +28,7 @@ __all__ = [ 'PydanticSerializationUnexpectedValue', 'PydanticUndefined', 'PydanticUndefinedType', + 'GatherInvalidDefinitionError', 'Some', 'to_json', 'from_json', @@ -35,6 +36,7 @@ __all__ = [ 'list_all_errors', 'TzInfo', 'validate_core_schema', + 'gather_schemas_for_cleaning', ] __version__: str build_profile: str @@ -1011,3 +1013,14 @@ def validate_core_schema(schema: CoreSchema, *, strict: bool | None = None) -> C We may also remove this function altogether, do not rely on it being present if you are using pydantic-core directly. """ +@final +class GatherInvalidDefinitionError(Exception): + """Internal error when encountering invalid definition refs""" + +def gather_schemas_for_cleaning( + schema: CoreSchema, + definitions: dict[str, CoreSchema], + find_meta_with_keys: set[str] | None, +) -> GatherResult: + """Used internally for schema cleaning when schemas are generated. + Gathers information from the schema tree for the cleaning.""" diff --git a/src/lib.rs b/src/lib.rs index 0fdb038ea..e30d04b29 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -22,6 +22,7 @@ mod errors; mod input; mod lookup_key; mod recursion_guard; +mod schema_traverse; mod serializers; mod tools; mod url; @@ -35,6 +36,7 @@ pub use build_tools::SchemaError; pub use errors::{ list_all_errors, PydanticCustomError, PydanticKnownError, PydanticOmit, PydanticUseDefault, ValidationError, }; +pub use schema_traverse::{gather_schemas_for_cleaning, GatherInvalidDefinitionError}; pub use serializers::{ to_json, to_jsonable_python, PydanticSerializationError, PydanticSerializationUnexpectedValue, SchemaSerializer, WarningsArg, @@ -129,10 +131,15 @@ fn _pydantic_core(py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add( + "GatherInvalidDefinitionError", + py.get_type_bound::(), + )?; m.add_function(wrap_pyfunction!(to_json, m)?)?; m.add_function(wrap_pyfunction!(from_json, m)?)?; m.add_function(wrap_pyfunction!(to_jsonable_python, m)?)?; m.add_function(wrap_pyfunction!(list_all_errors, m)?)?; + m.add_function(wrap_pyfunction!(gather_schemas_for_cleaning, m)?)?; m.add_function(wrap_pyfunction!(validate_core_schema, m)?)?; Ok(()) } diff --git a/src/schema_traverse.rs b/src/schema_traverse.rs new file mode 100644 index 000000000..270473191 --- /dev/null +++ b/src/schema_traverse.rs @@ -0,0 +1,201 @@ +use crate::tools::py_err; +use pyo3::exceptions::{PyException, PyKeyError}; +use pyo3::prelude::*; +use pyo3::types::{PyDict, PyList, PyNone, PySet, PyString, PyTuple}; +use pyo3::{create_exception, intern, Bound, PyResult}; + +create_exception!(pydantic_core._pydantic_core, GatherInvalidDefinitionError, PyException); + +macro_rules! none { + ($py: expr) => { + PyNone::get_bound($py) + }; +} + +macro_rules! get { + ($dict: expr, $key: expr) => { + $dict.get_item(intern!($dict.py(), $key))? + }; +} + +macro_rules! traverse_key_fn { + ($key: expr, $func: expr, $dict: expr, $ctx: expr) => {{ + if let Some(v) = get!($dict, $key) { + $func(v.downcast_exact()?, $ctx)? + } + }}; +} + +macro_rules! traverse { + ($($key:expr => $func:expr),*; $dict: expr, $ctx: expr) => {{ + $(traverse_key_fn!($key, $func, $dict, $ctx);)* + traverse_key_fn!("serialization", gather_schema, $dict, $ctx); + gather_meta($dict, $ctx) + }} +} + +macro_rules! defaultdict_list_append { + ($dict: expr, $key: expr, $value: expr) => {{ + match $dict.get_item($key)? { + None => { + let list = PyList::new_bound($dict.py(), [$value]); + $dict.set_item($key, list)?; + } + // Safety: we know that the value is a PyList as we just created it above + Some(list) => unsafe { list.downcast_unchecked::() }.append($value)?, + }; + }}; +} + +fn gather_definition_ref(schema_ref_dict: &Bound<'_, PyDict>, ctx: &mut GatherCtx) -> PyResult<()> { + let Some(schema_ref) = get!(schema_ref_dict, "schema_ref") else { + return py_err!(PyKeyError; "Invalid definition-ref, missing schema_ref"); + }; + let schema_ref = schema_ref.downcast_exact::()?; + let py = schema_ref_dict.py(); + + if !ctx.recursively_seen_refs.contains(schema_ref)? { + // Def ref in no longer consider as inlinable if its re-encountered. Then its used multiple times. + // No need to retraverse it either if we already encountered this. + if !ctx.inline_def_ref_candidates.contains(schema_ref)? { + let Some(definition) = ctx.definitions.get_item(schema_ref)? else { + return py_err!(GatherInvalidDefinitionError; "{}", schema_ref.to_str()?); + }; + + ctx.inline_def_ref_candidates.set_item(schema_ref, schema_ref_dict)?; + ctx.recursively_seen_refs.add(schema_ref)?; + + gather_schema(definition.downcast_exact()?, ctx)?; + traverse_key_fn!("serialization", gather_schema, schema_ref_dict, ctx); + gather_meta(schema_ref_dict, ctx)?; + + ctx.recursively_seen_refs.discard(schema_ref)?; + } else { + ctx.inline_def_ref_candidates.set_item(schema_ref, none!(py))?; // Mark not inlinable (used multiple times) + } + } else { + ctx.inline_def_ref_candidates.set_item(schema_ref, none!(py))?; // Mark not inlinable (used in recursion) + ctx.recursive_def_refs.add(schema_ref)?; + for seen_ref in ctx.recursively_seen_refs.iter() { + ctx.inline_def_ref_candidates.set_item(&seen_ref, none!(py))?; // Mark not inlinable (used in recursion) + ctx.recursive_def_refs.add(seen_ref)?; + } + } + Ok(()) +} + +fn gather_meta(schema: &Bound<'_, PyDict>, ctx: &mut GatherCtx) -> PyResult<()> { + let Some((res, find_keys)) = &ctx.meta_with_keys else { + return Ok(()); + }; + let Some(meta) = get!(schema, "metadata") else { + return Ok(()); + }; + let meta_dict = meta.downcast_exact::()?; + for k in find_keys.iter() { + if meta_dict.contains(&k)? { + defaultdict_list_append!(res, &k, schema); + } + } + Ok(()) +} + +fn gather_list(schema_list: &Bound<'_, PyList>, ctx: &mut GatherCtx) -> PyResult<()> { + for v in schema_list.iter() { + gather_schema(v.downcast_exact()?, ctx)?; + } + Ok(()) +} + +fn gather_dict(schemas_by_key: &Bound<'_, PyDict>, ctx: &mut GatherCtx) -> PyResult<()> { + for (_, v) in schemas_by_key.iter() { + gather_schema(v.downcast_exact()?, ctx)?; + } + Ok(()) +} + +fn gather_union_choices(schema_list: &Bound<'_, PyList>, ctx: &mut GatherCtx) -> PyResult<()> { + for v in schema_list.iter() { + if let Ok(tup) = v.downcast_exact::() { + gather_schema(tup.get_item(0)?.downcast_exact()?, ctx)?; + } else { + gather_schema(v.downcast_exact()?, ctx)?; + } + } + Ok(()) +} + +fn gather_arguments(arguments: &Bound<'_, PyList>, ctx: &mut GatherCtx) -> PyResult<()> { + for v in arguments.iter() { + traverse_key_fn!("schema", gather_schema, v.downcast_exact::()?, ctx); + } + Ok(()) +} + +// Has 100% coverage in Pydantic side. This is exclusively used there +#[cfg_attr(has_coverage_attribute, coverage(off))] +fn gather_schema(schema: &Bound<'_, PyDict>, ctx: &mut GatherCtx) -> PyResult<()> { + let Some(type_) = get!(schema, "type") else { + return py_err!(PyKeyError; "Schema type missing"); + }; + match type_.downcast_exact::()?.to_str()? { + "definition-ref" => gather_definition_ref(schema, ctx), + "definitions" => traverse!("schema" => gather_schema, "definitions" => gather_list; schema, ctx), + "list" | "set" | "frozenset" | "generator" => traverse!("items_schema" => gather_schema; schema, ctx), + "tuple" => traverse!("items_schema" => gather_list; schema, ctx), + "dict" => traverse!("keys_schema" => gather_schema, "values_schema" => gather_schema; schema, ctx), + "union" => traverse!("choices" => gather_union_choices; schema, ctx), + "tagged-union" => traverse!("choices" => gather_dict; schema, ctx), + "chain" => traverse!("steps" => gather_list; schema, ctx), + "lax-or-strict" => traverse!("lax_schema" => gather_schema, "strict_schema" => gather_schema; schema, ctx), + "json-or-python" => traverse!("json_schema" => gather_schema, "python_schema" => gather_schema; schema, ctx), + "model-fields" | "typed-dict" => traverse!( + "extras_schema" => gather_schema, "computed_fields" => gather_list, "fields" => gather_dict; schema, ctx + ), + "dataclass-args" => traverse!("computed_fields" => gather_list, "fields" => gather_list; schema, ctx), + "arguments" => traverse!( + "arguments_schema" => gather_arguments, + "var_args_schema" => gather_schema, + "var_kwargs_schema" => gather_schema; + schema, ctx + ), + "call" => traverse!("arguments_schema" => gather_schema, "return_schema" => gather_schema; schema, ctx), + "computed-field" | "function-plain" => traverse!("return_schema" => gather_schema; schema, ctx), + "function-wrap" => traverse!("return_schema" => gather_schema, "schema" => gather_schema; schema, ctx), + _ => traverse!("schema" => gather_schema; schema, ctx), + } +} + +struct GatherCtx<'a, 'py> { + definitions: &'a Bound<'py, PyDict>, + meta_with_keys: Option<(Bound<'py, PyDict>, &'a Bound<'py, PySet>)>, + inline_def_ref_candidates: Bound<'py, PyDict>, + recursive_def_refs: Bound<'py, PySet>, + recursively_seen_refs: Bound<'py, PySet>, +} + +#[pyfunction(signature = (schema, definitions, find_meta_with_keys))] +pub fn gather_schemas_for_cleaning<'py>( + schema: &Bound<'py, PyAny>, + definitions: &Bound<'py, PyAny>, + find_meta_with_keys: &Bound<'py, PyAny>, +) -> PyResult> { + let py = schema.py(); + let mut ctx = GatherCtx { + definitions: definitions.downcast_exact()?, + meta_with_keys: match find_meta_with_keys.is_none() { + true => None, + false => Some((PyDict::new_bound(py), find_meta_with_keys.downcast_exact::()?)), + }, + inline_def_ref_candidates: PyDict::new_bound(py), + recursive_def_refs: PySet::empty_bound(py)?, + recursively_seen_refs: PySet::empty_bound(py)?, + }; + gather_schema(schema.downcast_exact()?, &mut ctx)?; + + let res = PyDict::new_bound(py); + res.set_item(intern!(py, "inlinable_def_refs"), ctx.inline_def_ref_candidates)?; + res.set_item(intern!(py, "recursive_refs"), ctx.recursive_def_refs)?; + res.set_item(intern!(py, "schemas_with_meta_keys"), ctx.meta_with_keys.map(|v| v.0))?; + Ok(res) +} diff --git a/tests/benchmarks/test_gather_schemas_for_cleaning_benchmark.py b/tests/benchmarks/test_gather_schemas_for_cleaning_benchmark.py new file mode 100644 index 000000000..63378e6f8 --- /dev/null +++ b/tests/benchmarks/test_gather_schemas_for_cleaning_benchmark.py @@ -0,0 +1,17 @@ +from typing import Callable + +from pydantic_core import gather_schemas_for_cleaning + +from .nested_schema import inlined_schema, schema_using_defs + + +def test_nested_schema_using_defs(benchmark: Callable[..., None]) -> None: + schema = schema_using_defs() + definitions = {def_schema['ref']: def_schema for def_schema in schema['definitions']} + schema = schema['schema'] + benchmark(gather_schemas_for_cleaning, schema, definitions, None) + + +def test_nested_schema_inlined(benchmark: Callable[..., None]) -> None: + schema = inlined_schema() + benchmark(gather_schemas_for_cleaning, schema, {}, {'some_meta_key'}) diff --git a/tests/test_gather_schemas_for_cleaning.py b/tests/test_gather_schemas_for_cleaning.py new file mode 100644 index 000000000..e6bdedaf1 --- /dev/null +++ b/tests/test_gather_schemas_for_cleaning.py @@ -0,0 +1,134 @@ +import pytest + +from pydantic_core import GatherInvalidDefinitionError, core_schema, gather_schemas_for_cleaning + + +def test_no_refs(): + p1 = core_schema.arguments_parameter('a', core_schema.int_schema()) + p2 = core_schema.arguments_parameter('b', core_schema.int_schema()) + schema = core_schema.arguments_schema([p1, p2]) + res = gather_schemas_for_cleaning(schema, definitions={}, find_meta_with_keys=None) + assert res['inlinable_def_refs'] == {} + assert res['recursive_refs'] == set() + assert res['schemas_with_meta_keys'] is None + + +def test_simple_ref_schema(): + schema = core_schema.definition_reference_schema('ref1') + definitions = {'ref1': core_schema.int_schema(ref='ref1')} + + res = gather_schemas_for_cleaning(schema, definitions, find_meta_with_keys=None) + assert res['inlinable_def_refs'] == {'ref1': schema} and res['inlinable_def_refs']['ref1'] is schema + assert res['recursive_refs'] == set() + assert res['schemas_with_meta_keys'] is None + + +def test_deep_ref_schema_used_multiple_times(): + class Model: + pass + + ref11 = core_schema.definition_reference_schema('ref1') + ref12 = core_schema.definition_reference_schema('ref1') + ref2 = core_schema.definition_reference_schema('ref2') + + union = core_schema.union_schema([core_schema.int_schema(), (ref11, 'ref_label')]) + tup = core_schema.tuple_schema([ref12, core_schema.str_schema()]) + schema = core_schema.model_schema( + Model, + core_schema.model_fields_schema( + {'a': core_schema.model_field(union), 'b': core_schema.model_field(ref2), 'c': core_schema.model_field(tup)} + ), + ) + definitions = {'ref1': core_schema.str_schema(ref='ref1'), 'ref2': core_schema.bytes_schema(ref='ref2')} + + res = gather_schemas_for_cleaning(schema, definitions, find_meta_with_keys=None) + assert res['inlinable_def_refs'] == {'ref1': None, 'ref2': ref2} and res['inlinable_def_refs']['ref2'] is ref2 + assert res['recursive_refs'] == set() + assert res['schemas_with_meta_keys'] is None + + +def test_ref_in_serialization_schema(): + ref = core_schema.definition_reference_schema('ref1') + schema = core_schema.str_schema( + serialization=core_schema.plain_serializer_function_ser_schema(lambda v: v, return_schema=ref), + ) + res = gather_schemas_for_cleaning(schema, definitions={'ref1': core_schema.str_schema()}, find_meta_with_keys=None) + assert res['inlinable_def_refs'] == {'ref1': ref} and res['inlinable_def_refs']['ref1'] is ref + assert res['recursive_refs'] == set() + assert res['schemas_with_meta_keys'] is None + + +def test_recursive_ref_schema(): + ref1 = core_schema.definition_reference_schema('ref1') + res = gather_schemas_for_cleaning(ref1, definitions={'ref1': ref1}, find_meta_with_keys=None) + assert res['inlinable_def_refs'] == {'ref1': None} + assert res['recursive_refs'] == {'ref1'} + assert res['schemas_with_meta_keys'] is None + + +def test_deep_recursive_ref_schema(): + ref1 = core_schema.definition_reference_schema('ref1') + ref2 = core_schema.definition_reference_schema('ref2') + ref3 = core_schema.definition_reference_schema('ref3') + + res = gather_schemas_for_cleaning( + core_schema.union_schema([ref1, core_schema.int_schema()]), + definitions={ + 'ref1': core_schema.union_schema([core_schema.int_schema(), ref2]), + 'ref2': core_schema.union_schema([ref3, core_schema.float_schema()]), + 'ref3': core_schema.union_schema([ref1, core_schema.str_schema()]), + }, + find_meta_with_keys=None, + ) + assert res['inlinable_def_refs'] == {'ref1': None, 'ref2': None, 'ref3': None} + assert res['recursive_refs'] == {'ref1', 'ref2', 'ref3'} + assert res['schemas_with_meta_keys'] is None + + +def test_find_meta(): + class Model: + pass + + ref1 = core_schema.definition_reference_schema('ref1') + + field1 = core_schema.model_field(core_schema.str_schema()) + field1['metadata'] = {'find_meta1': 'foobar1', 'unknown': 'foobar2'} + + field2 = core_schema.model_field(core_schema.int_schema()) + field2['metadata'] = {'find_meta1': 'foobar3', 'find_meta2': 'foobar4'} + + schema = core_schema.model_schema( + Model, core_schema.model_fields_schema({'a': field1, 'b': ref1, 'c': core_schema.float_schema()}) + ) + res = gather_schemas_for_cleaning( + schema, definitions={'ref1': field2}, find_meta_with_keys={'find_meta1', 'find_meta2'} + ) + assert res['inlinable_def_refs'] == {'ref1': ref1} and res['inlinable_def_refs']['ref1'] is ref1 + assert res['recursive_refs'] == set() + assert res['schemas_with_meta_keys'] == {'find_meta1': [field1, field2], 'find_meta2': [field2]} + assert res['schemas_with_meta_keys']['find_meta1'][0] is field1 + assert res['schemas_with_meta_keys']['find_meta1'][1] is field2 + assert res['schemas_with_meta_keys']['find_meta2'][0] is field2 + + +def test_unknown_ref(): + ref1 = core_schema.definition_reference_schema('ref1') + schema = core_schema.tuple_schema([core_schema.int_schema(), ref1]) + with pytest.raises(GatherInvalidDefinitionError, match='ref1'): + gather_schemas_for_cleaning(schema, definitions={}, find_meta_with_keys=None) + + +def test_no_duplicate_ref_instances_gathered(): + schema1 = core_schema.tuple_schema([core_schema.str_schema(), core_schema.int_schema()]) + schema2 = core_schema.tuple_schema( + [core_schema.definition_reference_schema('ref1'), core_schema.definition_reference_schema('ref1')] + ) + schema3 = core_schema.tuple_schema( + [core_schema.definition_reference_schema('ref2'), core_schema.definition_reference_schema('ref2')] + ) + definitions = {'ref1': schema1, 'ref2': schema2} + + res = gather_schemas_for_cleaning(schema3, definitions=definitions, find_meta_with_keys=None) + assert res['inlinable_def_refs'] == {'ref1': None, 'ref2': None} + assert res['recursive_refs'] == set() + assert res['schemas_with_meta_keys'] is None