Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 30 additions & 22 deletions src/serializers/type_serializers/enum_.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use super::{BuildSerializer, CombinedSerializer, Extra, TypeSerializer};
pub struct EnumSerializer {
class: Py<PyType>,
serializer: Option<Box<CombinedSerializer>>,
use_enum_values: bool,
}

impl BuildSerializer for EnumSerializer {
Expand All @@ -29,18 +30,25 @@ impl BuildSerializer for EnumSerializer {
config: Option<&Bound<'_, PyDict>>,
_definitions: &mut DefinitionsBuilder<CombinedSerializer>,
) -> PyResult<CombinedSerializer> {
// existing sub_type logic
let sub_type: Option<String> = schema.get_as(intern!(schema.py(), "sub_type"))?;

let serializer = match sub_type.as_deref() {
Some("int") => Some(Box::new(IntSerializer::new().into())),
Some("str") => Some(Box::new(StrSerializer::new().into())),
Some("float") => Some(Box::new(FloatSerializer::new(schema.py(), config)?.into())),
Some(_) => return py_schema_err!("`sub_type` must be one of: 'int', 'str', 'float' or None"),
Some(_) => return py_schema_err!(),
None => None,
};

// Read the `use_enum_values` flag from model_config (default: false)
let use_enum_values = config
.and_then(|cfg| cfg.get_as(intern!(schema.py(), "use_enum_values")).ok().flatten())
.unwrap_or(false);

Ok(Self {
class: schema.get_as_req(intern!(schema.py(), "cls"))?,
serializer,
use_enum_values,
}
.into())
}
Expand All @@ -58,15 +66,14 @@ impl TypeSerializer for EnumSerializer {
) -> PyResult<PyObject> {
let py = value.py();
if value.is_exact_instance(self.class.bind(py)) {
// if we're in JSON mode, we need to get the value attribute and serialize that
if extra.mode.is_json() {
// if the model was configured with use_enum_values=True, unwrap the .value
if self.use_enum_values {
let dot_value = value.getattr(intern!(py, "value"))?;
match self.serializer {
Some(ref s) => s.to_python(&dot_value, include, exclude, extra),
match &self.serializer {
Some(s) => s.to_python(&dot_value, include, exclude, extra),
None => infer_to_python(&dot_value, include, exclude, extra),
}
} else {
// if we're not in JSON mode, we assume the value is safe to return directly
Ok(value.clone().unbind())
}
} else {
Expand All @@ -78,18 +85,19 @@ impl TypeSerializer for EnumSerializer {
fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult<Cow<'a, str>> {
let py = key.py();
if key.is_exact_instance(self.class.bind(py)) {
let dot_value = key.getattr(intern!(py, "value"))?;
let k = match self.serializer {
Some(ref s) => s.json_key(&dot_value, extra),
None => infer_json_key(&dot_value, extra),
}?;
// since dot_value is a local reference, we need to allocate it and returned an
// owned variant of cow.
Ok(Cow::Owned(k.into_owned()))
} else {
extra.warnings.on_fallback_py(self.get_name(), key, extra)?;
infer_json_key(key, extra)
// if use_enum_values, unwrap enum to its value before making the JSON key
if self.use_enum_values {
let dot_value = key.getattr(intern!(py, "value"))?;
let k = match &self.serializer {
Some(s) => s.json_key(&dot_value, extra),
None => infer_json_key(&dot_value, extra),
}?;
return Ok(Cow::Owned(k.into_owned()));
}
// otherwise, fall through to default
}
extra.warnings.on_fallback_py(self.get_name(), key, extra)?;
infer_json_key(key, extra)
}

fn serde_serialize<S: serde::ser::Serializer>(
Expand All @@ -102,8 +110,8 @@ impl TypeSerializer for EnumSerializer {
) -> Result<S::Ok, S::Error> {
if value.is_exact_instance(self.class.bind(value.py())) {
let dot_value = value.getattr(intern!(value.py(), "value")).map_err(py_err_se_err)?;
match self.serializer {
Some(ref s) => s.serde_serialize(&dot_value, serializer, include, exclude, extra),
match &self.serializer {
Some(s) => s.serde_serialize(&dot_value, serializer, include, exclude, extra),
None => infer_serialize(&dot_value, serializer, include, exclude, extra),
}
} else {
Expand All @@ -117,8 +125,8 @@ impl TypeSerializer for EnumSerializer {
}

fn retry_with_lax_check(&self) -> bool {
match self.serializer {
Some(ref s) => s.retry_with_lax_check(),
match &self.serializer {
Some(s) => s.retry_with_lax_check(),
None => false,
}
}
Expand Down
Loading