Skip to content

Commit 7ea0d64

Browse files
committed
checkpoint - using tagged_union_serializer too
1 parent 57c3c6f commit 7ea0d64

File tree

1 file changed

+103
-84
lines changed
  • src/serializers/type_serializers

1 file changed

+103
-84
lines changed

src/serializers/type_serializers/union.rs

Lines changed: 103 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use pyo3::prelude::*;
44
use pyo3::types::{PyDict, PyList, PyTuple};
55
use smallvec::SmallVec;
66
use std::borrow::Cow;
7+
use std::sync::Arc;
78

89
use crate::build_tools::py_schema_err;
910
use crate::common::union::{Discriminator, SMALL_UNION_THRESHOLD};
@@ -119,6 +120,41 @@ fn union_serialize<S, R>(
119120
Ok(finalizer(None))
120121
}
121122

123+
fn tagged_union_serialize<S>(
124+
discriminator_value: Option<Py<PyAny>>,
125+
lookup: &HashMap<String, usize>,
126+
// if this returns `Ok(v)`, we picked a union variant to serialize, where
127+
// `S` is intermediate state which can be passed on to the finalizer
128+
mut selector: impl FnMut(&CombinedSerializer, &Extra) -> PyResult<S>,
129+
extra: &Extra,
130+
choices: &Vec<CombinedSerializer>,
131+
retry_with_lax_check: bool,
132+
) -> Option<S> {
133+
let mut new_extra = extra.clone();
134+
new_extra.check = SerCheck::Strict;
135+
136+
if let Some(tag) = discriminator_value {
137+
let tag_str = tag.to_string();
138+
if let Some(&serializer_index) = lookup.get(&tag_str) {
139+
let selected_serializer = &choices[serializer_index];
140+
141+
match selector(&selected_serializer, &new_extra) {
142+
Ok(v) => return Some(v),
143+
Err(_) => {
144+
if retry_with_lax_check {
145+
new_extra.check = SerCheck::Lax;
146+
if let Ok(v) = selector(&selected_serializer, &new_extra) {
147+
return Some(v);
148+
}
149+
}
150+
}
151+
}
152+
}
153+
}
154+
155+
None
156+
}
157+
122158
impl TypeSerializer for UnionSerializer {
123159
fn to_python(
124160
&self,
@@ -237,67 +273,56 @@ impl TypeSerializer for TaggedUnionSerializer {
237273
exclude: Option<&Bound<'_, PyAny>>,
238274
extra: &Extra,
239275
) -> PyResult<PyObject> {
240-
let mut new_extra = extra.clone();
241-
new_extra.check = SerCheck::Strict;
242-
243-
if let Some(tag) = self.get_discriminator_value(value, extra) {
244-
let tag_str = tag.to_string();
245-
if let Some(&serializer_index) = self.lookup.get(&tag_str) {
246-
let serializer = &self.choices[serializer_index];
247-
248-
match serializer.to_python(value, include, exclude, &new_extra) {
249-
Ok(v) => return Ok(v),
250-
Err(_) => {
251-
if self.retry_with_lax_check() {
252-
new_extra.check = SerCheck::Lax;
253-
if let Ok(v) = serializer.to_python(value, include, exclude, &new_extra) {
254-
return Ok(v);
255-
}
256-
}
257-
}
258-
}
259-
}
260-
}
276+
let to_python_selector = |comb_serializer: &CombinedSerializer, new_extra: &Extra| {
277+
comb_serializer.to_python(value, include, exclude, new_extra)
278+
};
261279

262-
union_serialize(
263-
|comb_serializer, new_extra| comb_serializer.to_python(value, include, exclude, new_extra),
264-
|v| v.map_or_else(|| infer_to_python(value, include, exclude, extra), Ok),
280+
tagged_union_serialize(
281+
self.get_discriminator_value(value, extra),
282+
&self.lookup,
283+
to_python_selector,
265284
extra,
266285
&self.choices,
267286
self.retry_with_lax_check(),
268-
)?
287+
)
288+
.map_or_else(
289+
|| {
290+
union_serialize(
291+
to_python_selector,
292+
|v| v.map_or_else(|| infer_to_python(value, include, exclude, extra), Ok),
293+
extra,
294+
&self.choices,
295+
self.retry_with_lax_check(),
296+
)?
297+
},
298+
Ok,
299+
)
269300
}
270301

271302
fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult<Cow<'a, str>> {
272-
let mut new_extra = extra.clone();
273-
new_extra.check = SerCheck::Strict;
274-
275-
if let Some(tag) = self.get_discriminator_value(key, extra) {
276-
let tag_str = tag.to_string();
277-
if let Some(&serializer_index) = self.lookup.get(&tag_str) {
278-
let serializer = &self.choices[serializer_index];
279-
280-
match serializer.json_key(key, &new_extra) {
281-
Ok(v) => return Ok(v),
282-
Err(_) => {
283-
if self.retry_with_lax_check() {
284-
new_extra.check = SerCheck::Lax;
285-
if let Ok(v) = serializer.json_key(key, &new_extra) {
286-
return Ok(v);
287-
}
288-
}
289-
}
290-
}
291-
}
292-
}
303+
let json_key_selector =
304+
|comb_serializer: &CombinedSerializer, new_extra: &Extra| comb_serializer.json_key(key, new_extra);
293305

294-
union_serialize(
295-
|comb_serializer, new_extra| comb_serializer.json_key(key, new_extra),
296-
|v| v.map_or_else(|| infer_json_key(key, extra), Ok),
306+
tagged_union_serialize(
307+
self.get_discriminator_value(key, extra),
308+
&self.lookup,
309+
json_key_selector,
297310
extra,
298311
&self.choices,
299312
self.retry_with_lax_check(),
300-
)?
313+
)
314+
.map_or_else(
315+
|| {
316+
union_serialize(
317+
json_key_selector,
318+
|v| v.map_or_else(|| infer_json_key(key, extra), Ok),
319+
extra,
320+
&self.choices,
321+
self.retry_with_lax_check(),
322+
)?
323+
},
324+
Ok,
325+
)
301326
}
302327

303328
fn serde_serialize<S: serde::ser::Serializer>(
@@ -308,45 +333,39 @@ impl TypeSerializer for TaggedUnionSerializer {
308333
exclude: Option<&Bound<'_, PyAny>>,
309334
extra: &Extra,
310335
) -> Result<S::Ok, S::Error> {
311-
let py = value.py();
312-
let mut new_extra = extra.clone();
313-
new_extra.check = SerCheck::Strict;
314-
315-
if let Some(tag) = self.get_discriminator_value(value, extra) {
316-
let tag_str = tag.to_string();
317-
if let Some(&serializer_index) = self.lookup.get(&tag_str) {
318-
let selected_serializer = &self.choices[serializer_index];
319-
320-
match selected_serializer.to_python(value, include, exclude, &new_extra) {
321-
Ok(v) => return infer_serialize(v.bind(py), serializer, None, None, extra),
322-
Err(_) => {
323-
if self.retry_with_lax_check() {
324-
new_extra.check = SerCheck::Lax;
325-
if let Ok(v) = selected_serializer.to_python(value, include, exclude, &new_extra) {
326-
return infer_serialize(v.bind(py), serializer, None, None, extra);
327-
}
328-
}
329-
}
330-
}
331-
}
332-
}
336+
let serde_selector = |comb_serializer: &CombinedSerializer, new_extra: &Extra| {
337+
comb_serializer.to_python(value, include, exclude, new_extra)
338+
};
333339

334-
union_serialize(
335-
|comb_serializer, new_extra| comb_serializer.to_python(value, include, exclude, new_extra),
336-
|v| {
337-
infer_serialize(
338-
v.as_ref().map_or(value, |v| v.bind(value.py())),
339-
serializer,
340-
None,
341-
None,
342-
extra,
343-
)
344-
},
340+
tagged_union_serialize(
341+
None,
342+
&self.lookup,
343+
serde_selector,
345344
extra,
346345
&self.choices,
347346
self.retry_with_lax_check(),
348347
)
349-
.map_err(|err| serde::ser::Error::custom(err.to_string()))?
348+
.map_or_else(
349+
|| {
350+
union_serialize(
351+
serde_selector,
352+
|v| {
353+
infer_serialize(
354+
v.as_ref().map_or(value, |v| v.bind(value.py())),
355+
serializer,
356+
None,
357+
None,
358+
extra,
359+
)
360+
},
361+
extra,
362+
&self.choices,
363+
self.retry_with_lax_check(),
364+
)
365+
.map_err(|err| serde::ser::Error::custom(err.to_string()))?
366+
},
367+
|v| infer_serialize(v.bind(value.py()), serializer, None, None, extra),
368+
)
350369
}
351370

352371
fn get_name(&self) -> &str {

0 commit comments

Comments
 (0)