diff --git a/src/serializers/type_serializers/enum_.rs b/src/serializers/type_serializers/enum_.rs index d2dc40314..cc5d83b24 100644 --- a/src/serializers/type_serializers/enum_.rs +++ b/src/serializers/type_serializers/enum_.rs @@ -19,6 +19,7 @@ use super::{BuildSerializer, CombinedSerializer, Extra, TypeSerializer}; pub struct EnumSerializer { class: Py, serializer: Option>, + use_enum_values: bool, } impl BuildSerializer for EnumSerializer { @@ -29,18 +30,25 @@ impl BuildSerializer for EnumSerializer { config: Option<&Bound<'_, PyDict>>, _definitions: &mut DefinitionsBuilder, ) -> PyResult { + // existing sub_type logic let sub_type: Option = schema.get_as(intern!(schema.py(), "sub_type"))?; - let serializer = match sub_type.as_deref() { Some("int") => Some(Box::new(IntSerializer::new().into())), Some("str") => Some(Box::new(StrSerializer::new().into())), Some("float") => Some(Box::new(FloatSerializer::new(schema.py(), config)?.into())), - Some(_) => return py_schema_err!("`sub_type` must be one of: 'int', 'str', 'float' or None"), + Some(_) => return py_schema_err!(), None => None, }; + + // Read the `use_enum_values` flag from model_config (default: false) + let use_enum_values = config + .and_then(|cfg| cfg.get_as(intern!(schema.py(), "use_enum_values")).ok().flatten()) + .unwrap_or(false); + Ok(Self { class: schema.get_as_req(intern!(schema.py(), "cls"))?, serializer, + use_enum_values, } .into()) } @@ -58,15 +66,14 @@ impl TypeSerializer for EnumSerializer { ) -> PyResult { let py = value.py(); if value.is_exact_instance(self.class.bind(py)) { - // if we're in JSON mode, we need to get the value attribute and serialize that - if extra.mode.is_json() { + // if the model was configured with use_enum_values=True, unwrap the .value + if self.use_enum_values { let dot_value = value.getattr(intern!(py, "value"))?; - match self.serializer { - Some(ref s) => s.to_python(&dot_value, include, exclude, extra), + match &self.serializer { + Some(s) => s.to_python(&dot_value, include, exclude, extra), None => infer_to_python(&dot_value, include, exclude, extra), } } else { - // if we're not in JSON mode, we assume the value is safe to return directly Ok(value.clone().unbind()) } } else { @@ -78,18 +85,19 @@ impl TypeSerializer for EnumSerializer { fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult> { let py = key.py(); if key.is_exact_instance(self.class.bind(py)) { - let dot_value = key.getattr(intern!(py, "value"))?; - let k = match self.serializer { - Some(ref s) => s.json_key(&dot_value, extra), - None => infer_json_key(&dot_value, extra), - }?; - // since dot_value is a local reference, we need to allocate it and returned an - // owned variant of cow. - Ok(Cow::Owned(k.into_owned())) - } else { - extra.warnings.on_fallback_py(self.get_name(), key, extra)?; - infer_json_key(key, extra) + // if use_enum_values, unwrap enum to its value before making the JSON key + if self.use_enum_values { + let dot_value = key.getattr(intern!(py, "value"))?; + let k = match &self.serializer { + Some(s) => s.json_key(&dot_value, extra), + None => infer_json_key(&dot_value, extra), + }?; + return Ok(Cow::Owned(k.into_owned())); + } + // otherwise, fall through to default } + extra.warnings.on_fallback_py(self.get_name(), key, extra)?; + infer_json_key(key, extra) } fn serde_serialize( @@ -102,8 +110,8 @@ impl TypeSerializer for EnumSerializer { ) -> Result { if value.is_exact_instance(self.class.bind(value.py())) { let dot_value = value.getattr(intern!(value.py(), "value")).map_err(py_err_se_err)?; - match self.serializer { - Some(ref s) => s.serde_serialize(&dot_value, serializer, include, exclude, extra), + match &self.serializer { + Some(s) => s.serde_serialize(&dot_value, serializer, include, exclude, extra), None => infer_serialize(&dot_value, serializer, include, exclude, extra), } } else { @@ -117,8 +125,8 @@ impl TypeSerializer for EnumSerializer { } fn retry_with_lax_check(&self) -> bool { - match self.serializer { - Some(ref s) => s.retry_with_lax_check(), + match &self.serializer { + Some(s) => s.retry_with_lax_check(), None => false, } }