|
| 1 | +use crate::tools::py_err; |
| 2 | +use pyo3::exceptions::PyKeyError; |
| 3 | +use pyo3::prelude::*; |
| 4 | +use pyo3::types::{PyDict, PyList, PySet, PyString, PyTuple}; |
| 5 | +use pyo3::{intern, Bound, PyResult}; |
| 6 | +use std::collections::HashSet; |
| 7 | + |
| 8 | +const CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY: &str = "pydantic.internal.union_discriminator"; |
| 9 | + |
| 10 | +macro_rules! get { |
| 11 | + ($dict: expr, $key: expr) => { |
| 12 | + $dict.get_item(intern!($dict.py(), $key))? |
| 13 | + }; |
| 14 | +} |
| 15 | + |
| 16 | +macro_rules! traverse_key_fn { |
| 17 | + ($key: expr, $func: expr, $dict: expr, $ctx: expr) => {{ |
| 18 | + if let Some(v) = get!($dict, $key) { |
| 19 | + $func(v.downcast_exact()?, $ctx)? |
| 20 | + } |
| 21 | + }}; |
| 22 | +} |
| 23 | + |
| 24 | +macro_rules! traverse { |
| 25 | + ($($key:expr => $func:expr),*; $dict: expr, $ctx: expr) => {{ |
| 26 | + $(traverse_key_fn!($key, $func, $dict, $ctx);)* |
| 27 | + gather_serialization($dict, $ctx)?; |
| 28 | + gather_meta($dict, $ctx)?; |
| 29 | + }} |
| 30 | +} |
| 31 | + |
| 32 | +macro_rules! defaultdict_list_append { |
| 33 | + ($dict: expr, $key: expr, $value: expr) => {{ |
| 34 | + match $dict.get_item($key)? { |
| 35 | + None => { |
| 36 | + let list = PyList::empty_bound($dict.py()); |
| 37 | + list.append($value)?; |
| 38 | + $dict.set_item($key, list)?; |
| 39 | + } |
| 40 | + // Safety: we know that the value is a PyList as we just created it above |
| 41 | + Some(list) => unsafe { list.downcast_unchecked::<PyList>() }.append($value)?, |
| 42 | + }; |
| 43 | + }}; |
| 44 | +} |
| 45 | + |
| 46 | +fn gather_definition_ref(schema_ref_dict: &Bound<'_, PyDict>, ctx: &mut GatherCtx) -> PyResult<()> { |
| 47 | + if let Some(schema_ref) = get!(schema_ref_dict, "schema_ref") { |
| 48 | + let schema_ref_pystr = schema_ref.downcast_exact::<PyString>()?; |
| 49 | + let schema_ref_str = schema_ref_pystr.to_str()?; |
| 50 | + |
| 51 | + if !ctx.recursively_seen_refs.contains(schema_ref_str) { |
| 52 | + defaultdict_list_append!(&ctx.def_refs, schema_ref_pystr, schema_ref_dict); |
| 53 | + |
| 54 | + // TODO should py_err! when not found. That error can be used to detect the missing defs in cleaning side |
| 55 | + if let Some(definition) = ctx.definitions_dict.get_item(schema_ref_pystr)? { |
| 56 | + ctx.recursively_seen_refs.insert(schema_ref_str.to_string()); |
| 57 | + |
| 58 | + gather_schema(definition.downcast_exact::<PyDict>()?, ctx)?; |
| 59 | + gather_serialization(schema_ref_dict, ctx)?; |
| 60 | + gather_meta(schema_ref_dict, ctx)?; |
| 61 | + |
| 62 | + ctx.recursively_seen_refs.remove(schema_ref_str); |
| 63 | + } |
| 64 | + } else { |
| 65 | + ctx.recursive_def_refs.add(schema_ref_pystr)?; |
| 66 | + for seen_ref in &ctx.recursively_seen_refs { |
| 67 | + let seen_ref_pystr = PyString::new_bound(schema_ref.py(), seen_ref); |
| 68 | + ctx.recursive_def_refs.add(seen_ref_pystr)?; |
| 69 | + } |
| 70 | + } |
| 71 | + Ok(()) |
| 72 | + } else { |
| 73 | + py_err!(PyKeyError; "Invalid definition-ref, missing schema_ref") |
| 74 | + } |
| 75 | +} |
| 76 | + |
| 77 | +fn gather_serialization(schema: &Bound<'_, PyDict>, ctx: &mut GatherCtx) -> PyResult<()> { |
| 78 | + if let Some(ser) = get!(schema, "serialization") { |
| 79 | + let ser_dict = ser.downcast_exact::<PyDict>()?; |
| 80 | + traverse!("schema" => gather_schema, "return_schema" => gather_schema; ser_dict, ctx); |
| 81 | + } |
| 82 | + Ok(()) |
| 83 | +} |
| 84 | + |
| 85 | +fn gather_meta(schema: &Bound<'_, PyDict>, ctx: &mut GatherCtx) -> PyResult<()> { |
| 86 | + if let Some(meta) = get!(schema, "metadata") { |
| 87 | + let meta_dict = meta.downcast_exact::<PyDict>()?; |
| 88 | + if let Some(discriminator) = get!(meta_dict, CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY) { |
| 89 | + let schema_discriminator = PyTuple::new_bound(schema.py(), vec![schema.as_any(), &discriminator]); |
| 90 | + ctx.discriminators.append(schema_discriminator)?; |
| 91 | + } |
| 92 | + } |
| 93 | + Ok(()) |
| 94 | +} |
| 95 | + |
| 96 | +fn gather_list(schema_list: &Bound<'_, PyList>, ctx: &mut GatherCtx) -> PyResult<()> { |
| 97 | + for v in schema_list.iter() { |
| 98 | + gather_schema(v.downcast_exact()?, ctx)?; |
| 99 | + } |
| 100 | + Ok(()) |
| 101 | +} |
| 102 | + |
| 103 | +fn gather_dict(schemas_by_key: &Bound<'_, PyDict>, ctx: &mut GatherCtx) -> PyResult<()> { |
| 104 | + for (_, v) in schemas_by_key.iter() { |
| 105 | + gather_schema(v.downcast_exact()?, ctx)?; |
| 106 | + } |
| 107 | + Ok(()) |
| 108 | +} |
| 109 | + |
| 110 | +fn gather_union_choices(schema_list: &Bound<'_, PyList>, ctx: &mut GatherCtx) -> PyResult<()> { |
| 111 | + for v in schema_list.iter() { |
| 112 | + if let Ok(tup) = v.downcast_exact::<PyTuple>() { |
| 113 | + gather_schema(tup.get_item(0)?.downcast_exact()?, ctx)?; |
| 114 | + } else { |
| 115 | + gather_schema(v.downcast_exact()?, ctx)?; |
| 116 | + } |
| 117 | + } |
| 118 | + Ok(()) |
| 119 | +} |
| 120 | + |
| 121 | +fn gather_arguments(arguments: &Bound<'_, PyList>, ctx: &mut GatherCtx) -> PyResult<()> { |
| 122 | + for v in arguments.iter() { |
| 123 | + traverse_key_fn!("schema", gather_schema, v.downcast_exact::<PyDict>()?, ctx); |
| 124 | + } |
| 125 | + Ok(()) |
| 126 | +} |
| 127 | + |
| 128 | +// Has 100% coverage in Pydantic side. This is exclusively used there |
| 129 | +#[cfg_attr(has_coverage_attribute, coverage(off))] |
| 130 | +fn gather_schema(schema: &Bound<'_, PyDict>, ctx: &mut GatherCtx) -> PyResult<()> { |
| 131 | + let type_ = get!(schema, "type"); |
| 132 | + if type_.is_none() { |
| 133 | + return py_err!(PyKeyError; "Schema type missing"); |
| 134 | + } |
| 135 | + match type_.unwrap().downcast_exact::<PyString>()?.to_str()? { |
| 136 | + "definition-ref" => gather_definition_ref(schema, ctx)?, |
| 137 | + "definitions" => traverse!("schema" => gather_schema, "definitions" => gather_list; schema, ctx), |
| 138 | + "list" | "set" | "frozenset" | "generator" => traverse!("items_schema" => gather_schema; schema, ctx), |
| 139 | + "tuple" => traverse!("items_schema" => gather_list; schema, ctx), |
| 140 | + "dict" => traverse!("keys_schema" => gather_schema, "values_schema" => gather_schema; schema, ctx), |
| 141 | + "union" => traverse!("choices" => gather_union_choices; schema, ctx), |
| 142 | + "tagged-union" => traverse!("choices" => gather_dict; schema, ctx), |
| 143 | + "chain" => traverse!("steps" => gather_list; schema, ctx), |
| 144 | + "lax-or-strict" => traverse!("lax_schema" => gather_schema, "strict_schema" => gather_schema; schema, ctx), |
| 145 | + "json-or-python" => traverse!("json_schema" => gather_schema, "python_schema" => gather_schema; schema, ctx), |
| 146 | + "model-fields" | "typed-dict" => traverse!( |
| 147 | + "extras_schema" => gather_schema, "computed_fields" => gather_list, "fields" => gather_dict; schema, ctx |
| 148 | + ), |
| 149 | + "dataclass-args" => traverse!("computed_fields" => gather_list, "fields" => gather_list; schema, ctx), |
| 150 | + "arguments" => traverse!( |
| 151 | + "arguments_schema" => gather_arguments, |
| 152 | + "var_args_schema" => gather_schema, |
| 153 | + "var_kwargs_schema" => gather_schema; |
| 154 | + schema, ctx |
| 155 | + ), |
| 156 | + "call" => traverse!("arguments_schema" => gather_schema, "return_schema" => gather_schema; schema, ctx), |
| 157 | + "computed-field" | "function-plain" => traverse!("return_schema" => gather_schema; schema, ctx), |
| 158 | + "function-wrap" => traverse!("return_schema" => gather_schema, "schema" => gather_schema; schema, ctx), |
| 159 | + _ => traverse!("schema" => gather_schema; schema, ctx), |
| 160 | + }; |
| 161 | + Ok(()) |
| 162 | +} |
| 163 | + |
| 164 | +pub struct GatherCtx<'a, 'py> { |
| 165 | + pub definitions_dict: &'a Bound<'py, PyDict>, |
| 166 | + pub def_refs: Bound<'py, PyDict>, |
| 167 | + pub recursive_def_refs: Bound<'py, PySet>, |
| 168 | + pub discriminators: Bound<'py, PyList>, |
| 169 | + recursively_seen_refs: HashSet<String>, |
| 170 | +} |
| 171 | + |
| 172 | +#[pyfunction(signature = (schema, definitions))] |
| 173 | +pub fn gather_schemas_for_cleaning<'py>( |
| 174 | + schema: &Bound<'py, PyAny>, |
| 175 | + definitions: &Bound<'py, PyAny>, |
| 176 | +) -> PyResult<Bound<'py, PyDict>> { |
| 177 | + let py = schema.py(); |
| 178 | + let schema_dict = schema.downcast_exact::<PyDict>()?; |
| 179 | + |
| 180 | + let mut ctx = GatherCtx { |
| 181 | + definitions_dict: definitions.downcast_exact()?, |
| 182 | + def_refs: PyDict::new_bound(definitions.py()), |
| 183 | + recursive_def_refs: PySet::empty_bound(definitions.py())?, |
| 184 | + discriminators: PyList::empty_bound(definitions.py()), |
| 185 | + recursively_seen_refs: HashSet::new(), |
| 186 | + }; |
| 187 | + gather_schema(schema_dict, &mut ctx)?; |
| 188 | + |
| 189 | + let res = PyDict::new_bound(py); |
| 190 | + res.set_item(intern!(py, "definition_refs"), ctx.def_refs)?; |
| 191 | + res.set_item(intern!(py, "recursive_refs"), ctx.recursive_def_refs)?; |
| 192 | + res.set_item(intern!(py, "deferred_discriminators"), ctx.discriminators)?; |
| 193 | + Ok(res) |
| 194 | +} |
0 commit comments