diff --git a/pyproject.toml b/pyproject.toml index 9498ac8..dc1334a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -85,7 +85,7 @@ dev-pytest = [ ] marshmallow = [ - "marshmallow >= 3.0.0, < 4", + "marshmallow >= 3.0.0, < 5", "marshmallow-dataclass >= 8.0.0, < 9", ] diff --git a/src/frequenz/quantities/experimental/marshmallow.py b/src/frequenz/quantities/experimental/marshmallow.py index 483bbdf..9bb28e7 100644 --- a/src/frequenz/quantities/experimental/marshmallow.py +++ b/src/frequenz/quantities/experimental/marshmallow.py @@ -14,9 +14,11 @@ even in minor or patch releases. """ +from contextvars import ContextVar from typing import Any, Type -from marshmallow import Schema, ValidationError, fields +from marshmallow import Schema, ValidationError +from marshmallow.fields import Field from .._apparent_power import ApparentPower from .._current import Current @@ -29,8 +31,20 @@ from .._temperature import Temperature from .._voltage import Voltage +serialize_as_string_default: ContextVar[bool] = ContextVar( + "serialize_as_string_default", default=False +) +"""Context variable to control the default serialization format for quantities. -class _QuantityField(fields.Field): +If True, quantities are serialized as strings with units. +If False, quantities are serialized as floats. + +This can be overridden on a per-field basis using the `serialize_as_string` +metadata attribute. +""" + + +class _QuantityField(Field[Quantity]): """Custom field for Quantity objects supporting per-field serialization configuration. This class handles serialization and deserialization of ALL Quantity @@ -57,24 +71,34 @@ class _QuantityField(fields.Field): field_type: Type[Quantity] | None = None """The specific Quantity subclass.""" + def __init__(self, *args: Any, **kwargs: Any) -> None: + """Initialize the field.""" + self.serialize_as_string_override = kwargs.pop("serialize_as_string", None) + super().__init__(*args, **kwargs) + def _serialize( - self, value: Quantity, attr: str | None, obj: Any, **kwargs: Any + self, value: Quantity | None, attr: str | None, obj: Any, **kwargs: Any ) -> Any: """Serialize the Quantity object based on per-field configuration.""" if self.field_type is None or not issubclass(self.field_type, Quantity): raise TypeError( "field_type must be set to a Quantity subclass in the subclass." ) + if value is None: + return None - assert self.parent is not None + if not isinstance(value, Quantity): + raise TypeError( + f"Expected a Quantity object, but got {type(value).__name__}." + ) # Determine the serialization format - default = ( - False - if self.parent.context is None - else self.parent.context.get("serialize_as_string_default", False) + default = serialize_as_string_default.get() + serialize_as_string = ( + self.serialize_as_string_override + if self.serialize_as_string_override is not None + else default ) - serialize_as_string = self.metadata.get("serialize_as_string", default) if serialize_as_string: # Use the Quantity's native string representation (includes unit) @@ -177,7 +201,7 @@ class VoltageField(_QuantityField): field_type = Voltage -QUANTITY_FIELD_CLASSES: dict[type[Quantity], type[fields.Field]] = { +QUANTITY_FIELD_CLASSES: dict[type[Quantity], type[Field[Any]]] = { ApparentPower: ApparentPowerField, Current: CurrentField, Energy: EnergyField, @@ -208,8 +232,10 @@ class QuantitySchema(Schema): from marshmallow_dataclass import class_schema from marshmallow.validate import Range from frequenz.quantities import Percentage - from frequenz.quantities.experimental.marshmallow import QuantitySchema - from typing import cast + from frequenz.quantities.experimental.marshmallow import ( + QuantitySchema, + serialize_as_string_default, + ) @dataclass class Config: @@ -245,29 +271,24 @@ class Config: }, ) - @classmethod - def load(cls, config: dict[str, Any]) -> "Config": - schema = class_schema(cls, base_schema=QuantitySchema)( - serialize_as_string_default=True # type: ignore[call-arg] - ) - return cast(Config, schema.load(config)) + config_obj = Config() + Schema = class_schema(Config, base_schema=QuantitySchema) + schema = Schema() + + # Default serialization (as float) + result = schema.dump(config_obj) + assert result["percentage_serialized_as_schema_default"] == 25.0 + + # Override default serialization to string + serialize_as_string_default.set(True) + result = schema.dump(config_obj) + assert result["percentage_serialized_as_schema_default"] == "25.0 %" + serialize_as_string_default.set(False) # Reset context + + # Per-field configuration always takes precedence + assert result["percentage_always_as_string"] == "25.0 %" + assert result["percentage_always_as_float"] == 25.0 ``` """ - TYPE_MAPPING: dict[type[Quantity], type[fields.Field]] = QUANTITY_FIELD_CLASSES - - def __init__( - self, *args: Any, serialize_as_string_default: bool = False, **kwargs: Any - ) -> None: - """ - Initialize the schema with a default serialization format. - - Args: - *args: Additional positional arguments. - serialize_as_string_default: Default serialization format for quantities. - If True, quantities are serialized as strings with units. - If False, quantities are serialized as floats. - **kwargs: Additional keyword arguments. - """ - super().__init__(*args, **kwargs) - self.context["serialize_as_string_default"] = serialize_as_string_default + TYPE_MAPPING: dict[type, type[Field[Any]]] = QUANTITY_FIELD_CLASSES diff --git a/tests/experimental/test_marshmallow.py b/tests/experimental/test_marshmallow.py index d67eabf..2b880e7 100644 --- a/tests/experimental/test_marshmallow.py +++ b/tests/experimental/test_marshmallow.py @@ -18,7 +18,10 @@ Temperature, Voltage, ) -from frequenz.quantities.experimental.marshmallow import QuantitySchema +from frequenz.quantities.experimental.marshmallow import ( + QuantitySchema, + serialize_as_string_default, +) @dataclass @@ -74,9 +77,9 @@ class Config: default_factory=lambda: Voltage.from_kilovolts(200.0), metadata={ "metadata": { - "description": "A voltage field that is always serialized as a string", - "serialize_as_string": True, + "description": "A voltage field that is always serialized as a string" }, + "serialize_as_string": True, }, ) @@ -84,9 +87,9 @@ class Config: default_factory=lambda: Temperature.from_celsius(100.0), metadata={ "metadata": { - "description": "A temperature field that is never serialized as a string", - "serialize_as_string": False, + "description": "A temperature field that is never serialized as a string" }, + "serialize_as_string": False, }, ) @@ -96,11 +99,10 @@ def load(cls, config: dict[str, Any]) -> Self: schema = class_schema(cls, base_schema=QuantitySchema)() return cast(Self, schema.load(config)) - def dump(self, serialize_as_string_default: bool = False) -> dict[str, Any]: + def dump(self, use_string: bool = False) -> dict[str, Any]: """Dump the configuration.""" - schema = class_schema(Config, base_schema=QuantitySchema)( - serialize_as_string_default=serialize_as_string_default # type: ignore[call-arg] - ) + schema = class_schema(Config, base_schema=QuantitySchema)() + serialize_as_string_default.set(use_string) return cast(dict[str, Any], schema.dump(self)) @@ -208,7 +210,7 @@ def test_config_schema_dump_default_float() -> None: temp_never_string=Temperature.from_celsius(10.0), ) - dumped = config.dump(serialize_as_string_default=False) + dumped = config.dump(use_string=False) assert dumped == { "my_percent_field": 50.0, @@ -233,7 +235,7 @@ def test_config_schema_dump_default_string() -> None: temp_never_string=Temperature.from_celsius(10.0), ) - dumped = config.dump(serialize_as_string_default=True) + dumped = config.dump(use_string=True) assert dumped == { "my_percent_field": "50 %",