Skip to content

Commit b7fdf70

Browse files
committed
more union ser tidying
1 parent 2419981 commit b7fdf70

File tree

1 file changed

+22
-9
lines changed
  • src/serializers/type_serializers

1 file changed

+22
-9
lines changed

src/serializers/type_serializers/union.rs

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -422,19 +422,32 @@ impl TaggedUnionSerializer {
422422
fn get_discriminator_value(&self, value: &Bound<'_, PyAny>, extra: &Extra) -> Option<Py<PyAny>> {
423423
let py = value.py();
424424
let discriminator_value = match &self.discriminator {
425-
Discriminator::LookupKey(lookup_key) => lookup_key
426-
.simple_py_get_attr(value)
427-
.ok()
428-
.and_then(|opt| opt.map(|(_, bound)| bound.to_object(py))),
425+
Discriminator::LookupKey(lookup_key) => {
426+
// we're pretty lax here, we allow either dict[key] or object.key, as we very well could
427+
// be doing a discriminator lookup on a typed dict, and there's no good way to check that
428+
// at this point. we could be more strict and only do this in lax mode...
429+
let getattr_result = match value.is_instance_of::<PyDict>() {
430+
true => {
431+
let value_dict = value.downcast::<PyDict>().unwrap();
432+
lookup_key.py_get_dict_item(value_dict).ok()
433+
}
434+
false => lookup_key.simple_py_get_attr(value).ok(),
435+
};
436+
getattr_result.and_then(|opt| opt.map(|(_, bound)| bound.to_object(py)))
437+
}
429438
Discriminator::Function(func) => func.call1(py, (value,)).ok(),
430439
};
431440
if discriminator_value.is_none() {
432441
let value_str = truncate_safe_repr(value, None);
433-
extra.warnings.custom_warning(
434-
format!(
435-
"Failed to get discriminator value for tagged union serialization with value `{value_str}` - defaulting to left to right union serialization."
436-
)
437-
);
442+
443+
// If extra.check is SerCheck::None, we're in a top-level union. We should thus raise this warning
444+
if extra.check == SerCheck::None {
445+
extra.warnings.custom_warning(
446+
format!(
447+
"Failed to get discriminator value for tagged union serialization with value `{value_str}` - defaulting to left to right union serialization."
448+
)
449+
);
450+
}
438451
}
439452
discriminator_value
440453
}

0 commit comments

Comments
 (0)