Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ dev-pytest = [
]

marshmallow = [
"marshmallow >= 3.0.0, < 4",
"marshmallow >= 3.0.0, < 5",
"marshmallow-dataclass >= 8.0.0, < 9",
]

Expand Down
91 changes: 56 additions & 35 deletions src/frequenz/quantities/experimental/marshmallow.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@
even in minor or patch releases.
"""

from contextvars import ContextVar
from typing import Any, Type

from marshmallow import Schema, ValidationError, fields
from marshmallow import Schema, ValidationError
from marshmallow.fields import Field

from .._apparent_power import ApparentPower
from .._current import Current
Expand All @@ -29,8 +31,20 @@
from .._temperature import Temperature
from .._voltage import Voltage

serialize_as_string_default: ContextVar[bool] = ContextVar(
"serialize_as_string_default", default=False
)
"""Context variable to control the default serialization format for quantities.

class _QuantityField(fields.Field):
If True, quantities are serialized as strings with units.
If False, quantities are serialized as floats.

This can be overridden on a per-field basis using the `serialize_as_string`
metadata attribute.
"""


class _QuantityField(Field[Quantity]):
"""Custom field for Quantity objects supporting per-field serialization configuration.

This class handles serialization and deserialization of ALL Quantity
Expand All @@ -57,24 +71,34 @@ class _QuantityField(fields.Field):
field_type: Type[Quantity] | None = None
"""The specific Quantity subclass."""

def __init__(self, *args: Any, **kwargs: Any) -> None:
"""Initialize the field."""
self.serialize_as_string_override = kwargs.pop("serialize_as_string", None)
Copy link

Copilot AI Aug 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The serialize_as_string parameter is being popped from kwargs but this field doesn't appear to be documented or validated. Consider adding proper validation or documentation for this parameter to clarify its expected usage.

Suggested change
self.serialize_as_string_override = kwargs.pop("serialize_as_string", None)
"""
Initialize the field.
Args:
*args: Positional arguments passed to the base Field.
**kwargs: Keyword arguments passed to the base Field.
serialize_as_string (bool, optional): If set, overrides the default
serialization format for this field. If True, the field will
serialize as a string with units; if False, as a float. If not
provided, the default from the context variable is used.
Raises:
TypeError: If `serialize_as_string` is not a boolean or None.
"""
self.serialize_as_string_override = kwargs.pop("serialize_as_string", None)
if (
self.serialize_as_string_override is not None
and not isinstance(self.serialize_as_string_override, bool)
):
raise TypeError(
f"serialize_as_string must be a boolean or None, got {type(self.serialize_as_string_override).__name__}"
)

Copilot uses AI. Check for mistakes.
super().__init__(*args, **kwargs)

def _serialize(
self, value: Quantity, attr: str | None, obj: Any, **kwargs: Any
self, value: Quantity | None, attr: str | None, obj: Any, **kwargs: Any
) -> Any:
"""Serialize the Quantity object based on per-field configuration."""
if self.field_type is None or not issubclass(self.field_type, Quantity):
raise TypeError(
"field_type must be set to a Quantity subclass in the subclass."
)
if value is None:
return None

assert self.parent is not None
if not isinstance(value, Quantity):
raise TypeError(
f"Expected a Quantity object, but got {type(value).__name__}."
)

# Determine the serialization format
default = (
False
if self.parent.context is None
else self.parent.context.get("serialize_as_string_default", False)
default = serialize_as_string_default.get()
serialize_as_string = (
self.serialize_as_string_override
if self.serialize_as_string_override is not None
else default
)
serialize_as_string = self.metadata.get("serialize_as_string", default)

if serialize_as_string:
# Use the Quantity's native string representation (includes unit)
Expand Down Expand Up @@ -177,7 +201,7 @@ class VoltageField(_QuantityField):
field_type = Voltage


QUANTITY_FIELD_CLASSES: dict[type[Quantity], type[fields.Field]] = {
QUANTITY_FIELD_CLASSES: dict[type[Quantity], type[Field[Any]]] = {
ApparentPower: ApparentPowerField,
Current: CurrentField,
Energy: EnergyField,
Expand Down Expand Up @@ -208,8 +232,10 @@ class QuantitySchema(Schema):
from marshmallow_dataclass import class_schema
from marshmallow.validate import Range
from frequenz.quantities import Percentage
from frequenz.quantities.experimental.marshmallow import QuantitySchema
from typing import cast
from frequenz.quantities.experimental.marshmallow import (
QuantitySchema,
serialize_as_string_default,
)

@dataclass
class Config:
Expand Down Expand Up @@ -245,29 +271,24 @@ class Config:
},
)

@classmethod
def load(cls, config: dict[str, Any]) -> "Config":
schema = class_schema(cls, base_schema=QuantitySchema)(
serialize_as_string_default=True # type: ignore[call-arg]
)
return cast(Config, schema.load(config))
config_obj = Config()
Schema = class_schema(Config, base_schema=QuantitySchema)
schema = Schema()

# Default serialization (as float)
result = schema.dump(config_obj)
assert result["percentage_serialized_as_schema_default"] == 25.0

# Override default serialization to string
serialize_as_string_default.set(True)
result = schema.dump(config_obj)
assert result["percentage_serialized_as_schema_default"] == "25.0 %"
serialize_as_string_default.set(False) # Reset context
Copy link

Copilot AI Aug 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Setting the context variable in example code without proper cleanup could lead to unexpected behavior if the example is copy-pasted. Consider using a context manager or explicitly mentioning the need to reset the context variable in production code.

Suggested change
serialize_as_string_default.set(False) # Reset context
token = serialize_as_string_default.set(True)
result = schema.dump(config_obj)
assert result["percentage_serialized_as_schema_default"] == "25.0 %"
serialize_as_string_default.reset(token) # Reset context

Copilot uses AI. Check for mistakes.

# Per-field configuration always takes precedence
assert result["percentage_always_as_string"] == "25.0 %"
assert result["percentage_always_as_float"] == 25.0
```
"""

TYPE_MAPPING: dict[type[Quantity], type[fields.Field]] = QUANTITY_FIELD_CLASSES

def __init__(
self, *args: Any, serialize_as_string_default: bool = False, **kwargs: Any
) -> None:
"""
Initialize the schema with a default serialization format.

Args:
*args: Additional positional arguments.
serialize_as_string_default: Default serialization format for quantities.
If True, quantities are serialized as strings with units.
If False, quantities are serialized as floats.
**kwargs: Additional keyword arguments.
"""
super().__init__(*args, **kwargs)
self.context["serialize_as_string_default"] = serialize_as_string_default
TYPE_MAPPING: dict[type, type[Field[Any]]] = QUANTITY_FIELD_CLASSES
24 changes: 13 additions & 11 deletions tests/experimental/test_marshmallow.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@
Temperature,
Voltage,
)
from frequenz.quantities.experimental.marshmallow import QuantitySchema
from frequenz.quantities.experimental.marshmallow import (
QuantitySchema,
serialize_as_string_default,
)


@dataclass
Expand Down Expand Up @@ -74,19 +77,19 @@ class Config:
default_factory=lambda: Voltage.from_kilovolts(200.0),
metadata={
"metadata": {
"description": "A voltage field that is always serialized as a string",
"serialize_as_string": True,
"description": "A voltage field that is always serialized as a string"
},
"serialize_as_string": True,
Copy link

Copilot AI Aug 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The metadata structure has been changed to move serialize_as_string outside of the nested metadata dict. This inconsistency with the previous structure could cause confusion - consider documenting this change or ensuring consistency across all field definitions.

Suggested change
"serialize_as_string": True,
"description": "A voltage field that is always serialized as a string",
"serialize_as_string": True,
},

Copilot uses AI. Check for mistakes.
},
)

temp_never_string: Temperature = field(
default_factory=lambda: Temperature.from_celsius(100.0),
metadata={
"metadata": {
"description": "A temperature field that is never serialized as a string",
"serialize_as_string": False,
"description": "A temperature field that is never serialized as a string"
},
"serialize_as_string": False,
},
)

Expand All @@ -96,11 +99,10 @@ def load(cls, config: dict[str, Any]) -> Self:
schema = class_schema(cls, base_schema=QuantitySchema)()
return cast(Self, schema.load(config))

def dump(self, serialize_as_string_default: bool = False) -> dict[str, Any]:
def dump(self, use_string: bool = False) -> dict[str, Any]:
"""Dump the configuration."""
schema = class_schema(Config, base_schema=QuantitySchema)(
serialize_as_string_default=serialize_as_string_default # type: ignore[call-arg]
)
schema = class_schema(Config, base_schema=QuantitySchema)()
serialize_as_string_default.set(use_string)
return cast(dict[str, Any], schema.dump(self))


Expand Down Expand Up @@ -208,7 +210,7 @@ def test_config_schema_dump_default_float() -> None:
temp_never_string=Temperature.from_celsius(10.0),
)

dumped = config.dump(serialize_as_string_default=False)
dumped = config.dump(use_string=False)

assert dumped == {
"my_percent_field": 50.0,
Expand All @@ -233,7 +235,7 @@ def test_config_schema_dump_default_string() -> None:
temp_never_string=Temperature.from_celsius(10.0),
)

dumped = config.dump(serialize_as_string_default=True)
dumped = config.dump(use_string=True)

assert dumped == {
"my_percent_field": "50 %",
Expand Down