| 
 | 1 | +use crate::tools::py_err;  | 
 | 2 | +use pyo3::exceptions::{PyException, PyKeyError};  | 
 | 3 | +use pyo3::prelude::*;  | 
 | 4 | +use pyo3::types::{PyDict, PyList, PyNone, PySet, PyString, PyTuple};  | 
 | 5 | +use pyo3::{create_exception, intern, Bound, PyResult};  | 
 | 6 | + | 
 | 7 | +create_exception!(pydantic_core._pydantic_core, GatherInvalidDefinitionError, PyException);  | 
 | 8 | + | 
 | 9 | +macro_rules! none {  | 
 | 10 | +    ($py: expr) => {  | 
 | 11 | +        PyNone::get_bound($py)  | 
 | 12 | +    };  | 
 | 13 | +}  | 
 | 14 | + | 
 | 15 | +macro_rules! get {  | 
 | 16 | +    ($dict: expr, $key: expr) => {  | 
 | 17 | +        $dict.get_item(intern!($dict.py(), $key))?  | 
 | 18 | +    };  | 
 | 19 | +}  | 
 | 20 | + | 
 | 21 | +macro_rules! traverse_key_fn {  | 
 | 22 | +    ($key: expr, $func: expr, $dict: expr, $ctx: expr) => {{  | 
 | 23 | +        if let Some(v) = get!($dict, $key) {  | 
 | 24 | +            $func(v.downcast_exact()?, $ctx)?  | 
 | 25 | +        }  | 
 | 26 | +    }};  | 
 | 27 | +}  | 
 | 28 | + | 
 | 29 | +macro_rules! traverse {  | 
 | 30 | +    ($($key:expr => $func:expr),*; $dict: expr, $ctx: expr) => {{  | 
 | 31 | +        $(traverse_key_fn!($key, $func, $dict, $ctx);)*  | 
 | 32 | +        traverse_key_fn!("serialization", gather_schema, $dict, $ctx);  | 
 | 33 | +        gather_meta($dict, $ctx)  | 
 | 34 | +    }}  | 
 | 35 | +}  | 
 | 36 | + | 
 | 37 | +macro_rules! defaultdict_list_append {  | 
 | 38 | +    ($dict: expr, $key: expr, $value: expr) => {{  | 
 | 39 | +        match $dict.get_item($key)? {  | 
 | 40 | +            None => {  | 
 | 41 | +                let list = PyList::new_bound($dict.py(), [$value]);  | 
 | 42 | +                $dict.set_item($key, list)?;  | 
 | 43 | +            }  | 
 | 44 | +            // Safety: we know that the value is a PyList as we just created it above  | 
 | 45 | +            Some(list) => unsafe { list.downcast_unchecked::<PyList>() }.append($value)?,  | 
 | 46 | +        };  | 
 | 47 | +    }};  | 
 | 48 | +}  | 
 | 49 | + | 
 | 50 | +fn gather_definition_ref(schema_ref_dict: &Bound<'_, PyDict>, ctx: &mut GatherCtx) -> PyResult<()> {  | 
 | 51 | +    let Some(schema_ref) = get!(schema_ref_dict, "schema_ref") else {  | 
 | 52 | +        return py_err!(PyKeyError; "Invalid definition-ref, missing schema_ref");  | 
 | 53 | +    };  | 
 | 54 | +    let schema_ref = schema_ref.downcast_exact::<PyString>()?;  | 
 | 55 | +    let py = schema_ref_dict.py();  | 
 | 56 | + | 
 | 57 | +    if !ctx.recursively_seen_refs.contains(schema_ref)? {  | 
 | 58 | +        // Def ref in no longer consider as inlinable if its re-encountered. Then its used multiple times.  | 
 | 59 | +        // No need to retraverse it either if we already encountered this.  | 
 | 60 | +        if !ctx.inline_def_ref_candidates.contains(schema_ref)? {  | 
 | 61 | +            let Some(definition) = ctx.definitions.get_item(schema_ref)? else {  | 
 | 62 | +                return py_err!(GatherInvalidDefinitionError; "{}", schema_ref.to_str()?);  | 
 | 63 | +            };  | 
 | 64 | + | 
 | 65 | +            ctx.inline_def_ref_candidates.set_item(schema_ref, schema_ref_dict)?;  | 
 | 66 | +            ctx.recursively_seen_refs.add(schema_ref)?;  | 
 | 67 | + | 
 | 68 | +            gather_schema(definition.downcast_exact()?, ctx)?;  | 
 | 69 | +            traverse_key_fn!("serialization", gather_schema, schema_ref_dict, ctx);  | 
 | 70 | +            gather_meta(schema_ref_dict, ctx)?;  | 
 | 71 | + | 
 | 72 | +            ctx.recursively_seen_refs.discard(schema_ref)?;  | 
 | 73 | +        } else {  | 
 | 74 | +            ctx.inline_def_ref_candidates.set_item(schema_ref, none!(py))?; // Mark not inlinable (used multiple times)  | 
 | 75 | +        }  | 
 | 76 | +    } else {  | 
 | 77 | +        ctx.inline_def_ref_candidates.set_item(schema_ref, none!(py))?; // Mark not inlinable (used in recursion)  | 
 | 78 | +        ctx.recursive_def_refs.add(schema_ref)?;  | 
 | 79 | +        for seen_ref in ctx.recursively_seen_refs.iter() {  | 
 | 80 | +            ctx.inline_def_ref_candidates.set_item(&seen_ref, none!(py))?; // Mark not inlinable (used in recursion)  | 
 | 81 | +            ctx.recursive_def_refs.add(seen_ref)?;  | 
 | 82 | +        }  | 
 | 83 | +    }  | 
 | 84 | +    Ok(())  | 
 | 85 | +}  | 
 | 86 | + | 
 | 87 | +fn gather_meta(schema: &Bound<'_, PyDict>, ctx: &mut GatherCtx) -> PyResult<()> {  | 
 | 88 | +    let Some((res, find_keys)) = &ctx.meta_with_keys else {  | 
 | 89 | +        return Ok(());  | 
 | 90 | +    };  | 
 | 91 | +    let Some(meta) = get!(schema, "metadata") else {  | 
 | 92 | +        return Ok(());  | 
 | 93 | +    };  | 
 | 94 | +    let meta_dict = meta.downcast_exact::<PyDict>()?;  | 
 | 95 | +    for k in find_keys.iter() {  | 
 | 96 | +        if meta_dict.contains(&k)? {  | 
 | 97 | +            defaultdict_list_append!(res, &k, schema);  | 
 | 98 | +        }  | 
 | 99 | +    }  | 
 | 100 | +    Ok(())  | 
 | 101 | +}  | 
 | 102 | + | 
 | 103 | +fn gather_list(schema_list: &Bound<'_, PyList>, ctx: &mut GatherCtx) -> PyResult<()> {  | 
 | 104 | +    for v in schema_list.iter() {  | 
 | 105 | +        gather_schema(v.downcast_exact()?, ctx)?;  | 
 | 106 | +    }  | 
 | 107 | +    Ok(())  | 
 | 108 | +}  | 
 | 109 | + | 
 | 110 | +fn gather_dict(schemas_by_key: &Bound<'_, PyDict>, ctx: &mut GatherCtx) -> PyResult<()> {  | 
 | 111 | +    for (_, v) in schemas_by_key.iter() {  | 
 | 112 | +        gather_schema(v.downcast_exact()?, ctx)?;  | 
 | 113 | +    }  | 
 | 114 | +    Ok(())  | 
 | 115 | +}  | 
 | 116 | + | 
 | 117 | +fn gather_union_choices(schema_list: &Bound<'_, PyList>, ctx: &mut GatherCtx) -> PyResult<()> {  | 
 | 118 | +    for v in schema_list.iter() {  | 
 | 119 | +        if let Ok(tup) = v.downcast_exact::<PyTuple>() {  | 
 | 120 | +            gather_schema(tup.get_item(0)?.downcast_exact()?, ctx)?;  | 
 | 121 | +        } else {  | 
 | 122 | +            gather_schema(v.downcast_exact()?, ctx)?;  | 
 | 123 | +        }  | 
 | 124 | +    }  | 
 | 125 | +    Ok(())  | 
 | 126 | +}  | 
 | 127 | + | 
 | 128 | +fn gather_arguments(arguments: &Bound<'_, PyList>, ctx: &mut GatherCtx) -> PyResult<()> {  | 
 | 129 | +    for v in arguments.iter() {  | 
 | 130 | +        traverse_key_fn!("schema", gather_schema, v.downcast_exact::<PyDict>()?, ctx);  | 
 | 131 | +    }  | 
 | 132 | +    Ok(())  | 
 | 133 | +}  | 
 | 134 | + | 
 | 135 | +// Has 100% coverage in Pydantic side. This is exclusively used there  | 
 | 136 | +#[cfg_attr(has_coverage_attribute, coverage(off))]  | 
 | 137 | +fn gather_schema(schema: &Bound<'_, PyDict>, ctx: &mut GatherCtx) -> PyResult<()> {  | 
 | 138 | +    let Some(type_) = get!(schema, "type") else {  | 
 | 139 | +        return py_err!(PyKeyError; "Schema type missing");  | 
 | 140 | +    };  | 
 | 141 | +    match type_.downcast_exact::<PyString>()?.to_str()? {  | 
 | 142 | +        "definition-ref" => gather_definition_ref(schema, ctx),  | 
 | 143 | +        "definitions" => traverse!("schema" => gather_schema, "definitions" => gather_list; schema, ctx),  | 
 | 144 | +        "list" | "set" | "frozenset" | "generator" => traverse!("items_schema" => gather_schema; schema, ctx),  | 
 | 145 | +        "tuple" => traverse!("items_schema" => gather_list; schema, ctx),  | 
 | 146 | +        "dict" => traverse!("keys_schema" => gather_schema, "values_schema" => gather_schema; schema, ctx),  | 
 | 147 | +        "union" => traverse!("choices" => gather_union_choices; schema, ctx),  | 
 | 148 | +        "tagged-union" => traverse!("choices" => gather_dict; schema, ctx),  | 
 | 149 | +        "chain" => traverse!("steps" => gather_list; schema, ctx),  | 
 | 150 | +        "lax-or-strict" => traverse!("lax_schema" => gather_schema, "strict_schema" => gather_schema; schema, ctx),  | 
 | 151 | +        "json-or-python" => traverse!("json_schema" => gather_schema, "python_schema" => gather_schema; schema, ctx),  | 
 | 152 | +        "model-fields" | "typed-dict" => traverse!(  | 
 | 153 | +            "extras_schema" => gather_schema, "computed_fields" => gather_list, "fields" => gather_dict; schema, ctx  | 
 | 154 | +        ),  | 
 | 155 | +        "dataclass-args" => traverse!("computed_fields" => gather_list, "fields" => gather_list; schema, ctx),  | 
 | 156 | +        "arguments" => traverse!(  | 
 | 157 | +            "arguments_schema" => gather_arguments,  | 
 | 158 | +            "var_args_schema" => gather_schema,  | 
 | 159 | +            "var_kwargs_schema" => gather_schema;  | 
 | 160 | +            schema, ctx  | 
 | 161 | +        ),  | 
 | 162 | +        "call" => traverse!("arguments_schema" => gather_schema, "return_schema" => gather_schema; schema, ctx),  | 
 | 163 | +        "computed-field" | "function-plain" => traverse!("return_schema" => gather_schema; schema, ctx),  | 
 | 164 | +        "function-wrap" => traverse!("return_schema" => gather_schema, "schema" => gather_schema; schema, ctx),  | 
 | 165 | +        _ => traverse!("schema" => gather_schema; schema, ctx),  | 
 | 166 | +    }  | 
 | 167 | +}  | 
 | 168 | + | 
 | 169 | +struct GatherCtx<'a, 'py> {  | 
 | 170 | +    definitions: &'a Bound<'py, PyDict>,  | 
 | 171 | +    meta_with_keys: Option<(Bound<'py, PyDict>, &'a Bound<'py, PySet>)>,  | 
 | 172 | +    inline_def_ref_candidates: Bound<'py, PyDict>,  | 
 | 173 | +    recursive_def_refs: Bound<'py, PySet>,  | 
 | 174 | +    recursively_seen_refs: Bound<'py, PySet>,  | 
 | 175 | +}  | 
 | 176 | + | 
 | 177 | +#[pyfunction(signature = (schema, definitions, find_meta_with_keys))]  | 
 | 178 | +pub fn gather_schemas_for_cleaning<'py>(  | 
 | 179 | +    schema: &Bound<'py, PyAny>,  | 
 | 180 | +    definitions: &Bound<'py, PyAny>,  | 
 | 181 | +    find_meta_with_keys: &Bound<'py, PyAny>,  | 
 | 182 | +) -> PyResult<Bound<'py, PyDict>> {  | 
 | 183 | +    let py = schema.py();  | 
 | 184 | +    let mut ctx = GatherCtx {  | 
 | 185 | +        definitions: definitions.downcast_exact()?,  | 
 | 186 | +        meta_with_keys: match find_meta_with_keys.is_none() {  | 
 | 187 | +            true => None,  | 
 | 188 | +            false => Some((PyDict::new_bound(py), find_meta_with_keys.downcast_exact::<PySet>()?)),  | 
 | 189 | +        },  | 
 | 190 | +        inline_def_ref_candidates: PyDict::new_bound(py),  | 
 | 191 | +        recursive_def_refs: PySet::empty_bound(py)?,  | 
 | 192 | +        recursively_seen_refs: PySet::empty_bound(py)?,  | 
 | 193 | +    };  | 
 | 194 | +    gather_schema(schema.downcast_exact()?, &mut ctx)?;  | 
 | 195 | + | 
 | 196 | +    let res = PyDict::new_bound(py);  | 
 | 197 | +    res.set_item(intern!(py, "inlinable_def_refs"), ctx.inline_def_ref_candidates)?;  | 
 | 198 | +    res.set_item(intern!(py, "recursive_refs"), ctx.recursive_def_refs)?;  | 
 | 199 | +    res.set_item(intern!(py, "schemas_with_meta_keys"), ctx.meta_with_keys.map(|v| v.0))?;  | 
 | 200 | +    Ok(res)  | 
 | 201 | +}  | 
0 commit comments