@@ -8,19 +8,27 @@ use std::collections::HashSet;
88const CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY : & str = "pydantic.internal.union_discriminator" ;
99
1010macro_rules! get {
11- ( $dict: expr, $key: expr) => { {
11+ ( $dict: expr, $key: expr) => {
1212 $dict. get_item( intern!( $dict. py( ) , $key) ) ?
13- } } ;
13+ } ;
1414}
1515
16- macro_rules! traverse {
17- ( $func : expr, $key : expr, $dict: expr, $ctx: expr) => { {
18- if let Some ( v) = $dict . get_item ( intern !( $dict. py ( ) , $key) ) ? {
16+ macro_rules! traverse_key_fn {
17+ ( $key : expr, $func : expr, $dict: expr, $ctx: expr) => { {
18+ if let Some ( v) = get !( $dict, $key) {
1919 $func( v. downcast_exact( ) ?, $ctx) ?
2020 }
2121 } } ;
2222}
2323
24+ macro_rules! traverse {
25+ ( $( $key: expr => $func: expr) ,* ; $dict: expr, $ctx: expr) => { {
26+ $( traverse_key_fn!( $key, $func, $dict, $ctx) ; ) *
27+ gather_serialization( $dict, $ctx) ?;
28+ gather_meta( $dict, $ctx) ?;
29+ } }
30+ }
31+
2432macro_rules! defaultdict_list_append {
2533 ( $dict: expr, $key: expr, $value: expr) => { {
2634 match $dict. get_item( $key) ? {
@@ -35,53 +43,51 @@ macro_rules! defaultdict_list_append {
3543 } } ;
3644}
3745
38- fn gather_definition_ref ( schema_ref_dict : & Bound < ' _ , PyDict > , ctx : & mut GatherCtx ) -> PyResult < bool > {
46+ fn gather_definition_ref ( schema_ref_dict : & Bound < ' _ , PyDict > , ctx : & mut GatherCtx ) -> PyResult < ( ) > {
3947 if let Some ( schema_ref) = get ! ( schema_ref_dict, "schema_ref" ) {
4048 let schema_ref_pystr = schema_ref. downcast_exact :: < PyString > ( ) ?;
4149 let schema_ref_str = schema_ref_pystr. to_str ( ) ?;
4250 defaultdict_list_append ! ( & ctx. def_refs, schema_ref_pystr, schema_ref_dict) ;
4351
4452 if !ctx. recursively_seen_refs . contains ( schema_ref_str) {
4553 // TODO should py_err! when not found. That error can be used to detect the missing defs in cleaning side
46- if let Some ( def ) = ctx. definitions_dict . get_item ( schema_ref_pystr) ? {
54+ if let Some ( definition ) = ctx. definitions_dict . get_item ( schema_ref_pystr) ? {
4755 ctx. recursively_seen_refs . insert ( schema_ref_str. to_string ( ) ) ;
48- gather_schema ( def. downcast_exact :: < PyDict > ( ) ?, ctx) ?;
56+
57+ gather_schema ( definition. downcast_exact :: < PyDict > ( ) ?, ctx) ?;
58+ gather_serialization ( schema_ref_dict, ctx) ?;
59+ gather_meta ( schema_ref_dict, ctx) ?;
60+
4961 ctx. recursively_seen_refs . remove ( schema_ref_str) ;
5062 }
51- Ok ( false )
5263 } else {
5364 ctx. recursive_def_refs . add ( schema_ref_pystr) ?;
54- for r in & ctx. recursively_seen_refs {
55- ctx. recursive_def_refs . add ( PyString :: new_bound ( schema_ref. py ( ) , r) ) ?;
65+ for seen_ref in & ctx. recursively_seen_refs {
66+ let seen_ref_pystr = PyString :: new_bound ( schema_ref. py ( ) , seen_ref) ;
67+ ctx. recursive_def_refs . add ( seen_ref_pystr) ?;
5668 }
57- Ok ( true )
5869 }
70+ Ok ( ( ) )
5971 } else {
60- py_err ! ( PyKeyError ; "Invalid definition-ref, missing schema_ref" ) ?
72+ py_err ! ( PyKeyError ; "Invalid definition-ref, missing schema_ref" )
6173 }
6274}
6375
64- fn gather_meta ( schema : & Bound < ' _ , PyDict > , meta_dict : & Bound < ' _ , PyDict > , ctx : & mut GatherCtx ) -> PyResult < ( ) > {
65- if let Some ( discriminator ) = get ! ( meta_dict , CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY ) {
66- let schema_discriminator = PyTuple :: new_bound ( schema . py ( ) , vec ! [ schema . as_any ( ) , & discriminator ] ) ;
67- ctx . discriminators . append ( schema_discriminator ) ? ;
76+ fn gather_serialization ( schema : & Bound < ' _ , PyDict > , ctx : & mut GatherCtx ) -> PyResult < ( ) > {
77+ if let Some ( ser ) = get ! ( schema , "serialization" ) {
78+ let ser_dict = ser . downcast_exact :: < PyDict > ( ) ? ;
79+ traverse ! ( "schema" => gather_schema , "return_schema" => gather_schema ; ser_dict , ctx ) ;
6880 }
6981 Ok ( ( ) )
7082}
7183
72- fn gather_node ( schema : & Bound < ' _ , PyDict > , ctx : & mut GatherCtx ) -> PyResult < ( ) > {
73- let type_ = get ! ( schema, "type" ) ;
74- if type_. is_none ( ) {
75- return py_err ! ( PyValueError ; "Schema type missing" ) ;
76- }
77- if type_. unwrap ( ) . downcast_exact :: < PyString > ( ) ?. to_str ( ) ? == "definition-ref" {
78- let recursive = gather_definition_ref ( schema, ctx) ?;
79- if recursive {
80- return Ok ( ( ) ) ;
81- }
82- }
84+ fn gather_meta ( schema : & Bound < ' _ , PyDict > , ctx : & mut GatherCtx ) -> PyResult < ( ) > {
8385 if let Some ( meta) = get ! ( schema, "metadata" ) {
84- gather_meta ( schema, meta. downcast_exact ( ) ?, ctx) ?;
86+ let meta_dict = meta. downcast_exact :: < PyDict > ( ) ?;
87+ if let Some ( discriminator) = get ! ( meta_dict, CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY ) {
88+ let schema_discriminator = PyTuple :: new_bound ( schema. py ( ) , vec ! [ schema. as_any( ) , & discriminator] ) ;
89+ ctx. discriminators . append ( schema_discriminator) ?;
90+ }
8591 }
8692 Ok ( ( ) )
8793}
@@ -113,71 +119,42 @@ fn gather_union_choices(schema_list: &Bound<'_, PyList>, ctx: &mut GatherCtx) ->
113119
114120fn gather_arguments ( arguments : & Bound < ' _ , PyList > , ctx : & mut GatherCtx ) -> PyResult < ( ) > {
115121 for v in arguments. iter ( ) {
116- if let Some ( schema) = get ! ( v. downcast_exact:: <PyDict >( ) ?, "schema" ) {
117- gather_schema ( schema. downcast_exact ( ) ?, ctx) ?;
118- }
122+ traverse_key_fn ! ( "schema" , gather_schema, v. downcast_exact:: <PyDict >( ) ?, ctx) ;
119123 }
120124 Ok ( ( ) )
121125}
122126
123- fn traverse_schema ( schema : & Bound < ' _ , PyDict > , ctx : & mut GatherCtx ) -> PyResult < ( ) > {
127+ fn gather_schema ( schema : & Bound < ' _ , PyDict > , ctx : & mut GatherCtx ) -> PyResult < ( ) > {
124128 let type_ = get ! ( schema, "type" ) ;
125129 if type_. is_none ( ) {
126130 return py_err ! ( PyValueError ; "Schema type missing" ) ;
127131 }
128132 match type_. unwrap ( ) . downcast_exact :: < PyString > ( ) ?. to_str ( ) ? {
129- "definitions" => {
130- traverse ! ( gather_schema, "schema" , schema, ctx) ;
131- traverse ! ( gather_list, "definitions" , schema, ctx) ;
132- }
133- "list" | "set" | "frozenset" | "generator" => traverse ! ( gather_schema, "items_schema" , schema, ctx) ,
134- "tuple" => traverse ! ( gather_list, "items_schema" , schema, ctx) ,
135- "dict" => {
136- traverse ! ( gather_schema, "keys_schema" , schema, ctx) ;
137- traverse ! ( gather_schema, "values_schema" , schema, ctx) ;
138- }
139- "union" => traverse ! ( gather_union_choices, "choices" , schema, ctx) ,
140- "tagged-union" => traverse ! ( gather_dict, "choices" , schema, ctx) ,
141- "chain" => traverse ! ( gather_list, "steps" , schema, ctx) ,
142- "lax-or-strict" => {
143- traverse ! ( gather_schema, "lax_schema" , schema, ctx) ;
144- traverse ! ( gather_schema, "strict_schema" , schema, ctx) ;
145- }
146- "json-or-python" => {
147- traverse ! ( gather_schema, "json_schema" , schema, ctx) ;
148- traverse ! ( gather_schema, "python_schema" , schema, ctx) ;
149- }
150- "model-fields" | "typed-dict" => {
151- traverse ! ( gather_schema, "extras_schema" , schema, ctx) ;
152- traverse ! ( gather_list, "computed_fields" , schema, ctx) ;
153- traverse ! ( gather_dict, "fields" , schema, ctx) ;
154- }
155- "dataclass-args" => {
156- traverse ! ( gather_list, "computed_fields" , schema, ctx) ;
157- traverse ! ( gather_list, "fields" , schema, ctx) ;
158- }
159- "arguments" => {
160- traverse ! ( gather_arguments, "arguments_schema" , schema, ctx) ;
161- traverse ! ( gather_schema, "var_args_schema" , schema, ctx) ;
162- traverse ! ( gather_schema, "var_kwargs_schema" , schema, ctx) ;
163- }
164- "computed-field" | "function-plain" => traverse ! ( gather_schema, "return_schema" , schema, ctx) ,
165- "function-wrap" => {
166- traverse ! ( gather_schema, "return_schema" , schema, ctx) ;
167- traverse ! ( gather_schema, "schema" , schema, ctx) ;
168- }
169- "call" => {
170- traverse ! ( gather_schema, "arguments_schema" , schema, ctx) ;
171- traverse ! ( gather_schema, "return_schema" , schema, ctx) ;
172- }
173- _ => traverse ! ( gather_schema, "schema" , schema, ctx) ,
133+ "definition-ref" => gather_definition_ref ( schema, ctx) ?,
134+ "definitions" => traverse ! ( "schema" => gather_schema, "definitions" => gather_list; schema, ctx) ,
135+ "list" | "set" | "frozenset" | "generator" => traverse ! ( "items_schema" => gather_schema; schema, ctx) ,
136+ "tuple" => traverse ! ( "items_schema" => gather_list; schema, ctx) ,
137+ "dict" => traverse ! ( "keys_schema" => gather_schema, "values_schema" => gather_schema; schema, ctx) ,
138+ "union" => traverse ! ( "choices" => gather_union_choices; schema, ctx) ,
139+ "tagged-union" => traverse ! ( "choices" => gather_dict; schema, ctx) ,
140+ "chain" => traverse ! ( "steps" => gather_list; schema, ctx) ,
141+ "lax-or-strict" => traverse ! ( "lax_schema" => gather_schema, "strict_schema" => gather_schema; schema, ctx) ,
142+ "json-or-python" => traverse ! ( "json_schema" => gather_schema, "python_schema" => gather_schema; schema, ctx) ,
143+ "model-fields" | "typed-dict" => traverse ! (
144+ "extras_schema" => gather_schema, "computed_fields" => gather_list, "fields" => gather_dict; schema, ctx
145+ ) ,
146+ "dataclass-args" => traverse ! ( "computed_fields" => gather_list, "fields" => gather_list; schema, ctx) ,
147+ "arguments" => traverse ! (
148+ "arguments_schema" => gather_arguments,
149+ "var_args_schema" => gather_schema,
150+ "var_kwargs_schema" => gather_schema;
151+ schema, ctx
152+ ) ,
153+ "call" => traverse ! ( "arguments_schema" => gather_schema, "return_schema" => gather_schema; schema, ctx) ,
154+ "computed-field" | "function-plain" => traverse ! ( "return_schema" => gather_schema; schema, ctx) ,
155+ "function-wrap" => traverse ! ( "return_schema" => gather_schema, "schema" => gather_schema; schema, ctx) ,
156+ _ => traverse ! ( "schema" => gather_schema; schema, ctx) ,
174157 } ;
175-
176- if let Some ( ser) = get ! ( schema, "serialization" ) {
177- let ser_dict = ser. downcast_exact :: < PyDict > ( ) ?;
178- traverse ! ( gather_schema, "schema" , ser_dict, ctx) ;
179- traverse ! ( gather_schema, "return_schema" , ser_dict, ctx) ;
180- }
181158 Ok ( ( ) )
182159}
183160
@@ -189,24 +166,6 @@ pub struct GatherCtx<'a, 'py> {
189166 recursively_seen_refs : HashSet < String > ,
190167}
191168
192- impl < ' a , ' py > GatherCtx < ' a , ' py > {
193- pub fn new ( definitions : & ' a Bound < ' py , PyDict > ) -> PyResult < Self > {
194- let ctx = GatherCtx {
195- definitions_dict : definitions,
196- def_refs : PyDict :: new_bound ( definitions. py ( ) ) ,
197- recursive_def_refs : PySet :: empty_bound ( definitions. py ( ) ) ?,
198- discriminators : PyList :: empty_bound ( definitions. py ( ) ) ,
199- recursively_seen_refs : HashSet :: new ( ) ,
200- } ;
201- Ok ( ctx)
202- }
203- }
204-
205- fn gather_schema ( schema : & Bound < ' _ , PyDict > , ctx : & mut GatherCtx ) -> PyResult < ( ) > {
206- traverse_schema ( schema, ctx) ?;
207- gather_node ( schema, ctx)
208- }
209-
210169#[ pyfunction( signature = ( schema, definitions) ) ]
211170pub fn gather_schemas_for_cleaning < ' py > (
212171 schema : & Bound < ' py , PyAny > ,
@@ -215,7 +174,13 @@ pub fn gather_schemas_for_cleaning<'py>(
215174 let py = schema. py ( ) ;
216175 let schema_dict = schema. downcast_exact :: < PyDict > ( ) ?;
217176
218- let mut ctx = GatherCtx :: new ( definitions. downcast_exact ( ) ?) ?;
177+ let mut ctx = GatherCtx {
178+ definitions_dict : definitions. downcast_exact ( ) ?,
179+ def_refs : PyDict :: new_bound ( definitions. py ( ) ) ,
180+ recursive_def_refs : PySet :: empty_bound ( definitions. py ( ) ) ?,
181+ discriminators : PyList :: empty_bound ( definitions. py ( ) ) ,
182+ recursively_seen_refs : HashSet :: new ( ) ,
183+ } ;
219184 gather_schema ( schema_dict, & mut ctx) ?;
220185
221186 let res = PyDict :: new_bound ( py) ;
0 commit comments