Skip to content

Commit 60a9acb

Browse files
committed
Use decorator for error handling in from_pb methods
To support logging an error in case there are problems when converting the protobuf message. Signed-off-by: cwasicki <[email protected]>
1 parent aed361e commit 60a9acb

File tree

1 file changed

+71
-64
lines changed
  • src/frequenz/client/electricity_trading

1 file changed

+71
-64
lines changed

src/frequenz/client/electricity_trading/_types.py

Lines changed: 71 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
from dataclasses import dataclass
1313
from datetime import datetime, timedelta, timezone
1414
from decimal import Decimal
15-
from typing import Self
15+
from functools import wraps
16+
from typing import Any, Callable, Self, Type, TypeVar
1617

1718
# pylint: disable=no-member
1819
from frequenz.api.common.v1.grid import delivery_area_pb2, delivery_duration_pb2
@@ -24,6 +25,32 @@
2425
_logger = logging.getLogger(__name__)
2526

2627

28+
T = TypeVar("T") # Generic type variable for class methods
29+
30+
31+
def from_pb(func: Callable[[Type[T], Any], T]) -> Callable[[Type[T], Any], T]:
32+
"""Standardize from_pb methods like error handling with this decorator.
33+
34+
Args:
35+
func: A class method that converts a protobuf message into an object.
36+
37+
Returns:
38+
The wrapped function with standardized error handling.
39+
"""
40+
41+
@wraps(func)
42+
def wrapper(cls: Type[T], pb_obj: Any) -> T:
43+
try:
44+
return func(cls, pb_obj)
45+
except Exception as e:
46+
_logger.error(
47+
"Error converting %s from protobuf (`%s`): %s", cls.__name__, pb_obj, e
48+
)
49+
raise
50+
51+
return wrapper
52+
53+
2754
# From frequanz.api.common
2855
class Currency(enum.Enum):
2956
"""
@@ -91,6 +118,7 @@ class Price:
91118
"""Currency of the price."""
92119

93120
@classmethod
121+
@from_pb
94122
def from_pb(cls, price: price_pb2.Price) -> Self:
95123
"""Convert a protobuf Price to Price object.
96124
@@ -99,18 +127,11 @@ def from_pb(cls, price: price_pb2.Price) -> Self:
99127
100128
Returns:
101129
Price object corresponding to the protobuf message.
102-
103-
Raises:
104-
Exception: If an error occurs during conversion.
105130
"""
106-
try:
107-
return cls(
108-
amount=Decimal(price.amount.value),
109-
currency=Currency.from_pb(price.currency),
110-
)
111-
except Exception as e:
112-
_logger.error("Error converting price `%s`: %s", price, e)
113-
raise
131+
return cls(
132+
amount=Decimal(price.amount.value),
133+
currency=Currency.from_pb(price.currency),
134+
)
114135

115136
def to_pb(self) -> price_pb2.Price:
116137
"""Convert a Price object to protobuf Price.
@@ -138,6 +159,7 @@ class Power:
138159
mw: Decimal
139160

140161
@classmethod
162+
@from_pb
141163
def from_pb(cls, power: power_pb2.Power) -> Self:
142164
"""Convert a protobuf Power to Power object.
143165
@@ -146,15 +168,8 @@ def from_pb(cls, power: power_pb2.Power) -> Self:
146168
147169
Returns:
148170
Power object corresponding to the protobuf message.
149-
150-
Raises:
151-
Exception: If an error occurs during conversion.
152171
"""
153-
try:
154-
return cls(mw=Decimal(power.mw.value))
155-
except Exception as e:
156-
_logger.error("Error converting power `%s`: %s", power, e)
157-
raise
172+
return cls(mw=Decimal(power.mw.value))
158173

159174
def to_pb(self) -> power_pb2.Power:
160175
"""Convert a Power object to protobuf Power.
@@ -965,6 +980,7 @@ def __post_init__(self) -> None:
965980
self.valid_until = self.valid_until.astimezone(timezone.utc)
966981

967982
@classmethod
983+
@from_pb
968984
def from_pb(cls, order: electricity_trading_pb2.Order) -> Self:
969985
"""Convert a protobuf Order to Order object.
970986
@@ -973,51 +989,42 @@ def from_pb(cls, order: electricity_trading_pb2.Order) -> Self:
973989
974990
Returns:
975991
Order object corresponding to the protobuf message.
976-
977-
Raises:
978-
Exception: If an error occurs during conversion.
979992
"""
980-
try:
981-
return cls(
982-
delivery_area=DeliveryArea.from_pb(order.delivery_area),
983-
delivery_period=DeliveryPeriod.from_pb(order.delivery_period),
984-
type=OrderType.from_pb(order.type),
985-
side=MarketSide.from_pb(order.side),
986-
price=Price.from_pb(order.price),
987-
quantity=Power.from_pb(order.quantity),
988-
stop_price=(
989-
Price.from_pb(order.stop_price)
990-
if order.HasField("stop_price")
991-
else None
992-
),
993-
peak_price_delta=(
994-
Price.from_pb(order.peak_price_delta)
995-
if order.HasField("peak_price_delta")
996-
else None
997-
),
998-
display_quantity=(
999-
Power.from_pb(order.display_quantity)
1000-
if order.HasField("display_quantity")
1001-
else None
1002-
),
1003-
execution_option=(
1004-
OrderExecutionOption.from_pb(order.execution_option)
1005-
if order.HasField("execution_option")
1006-
else None
1007-
),
1008-
valid_until=(
1009-
order.valid_until.ToDatetime(tzinfo=timezone.utc)
1010-
if order.HasField("valid_until")
1011-
else None
1012-
),
1013-
payload=(
1014-
json_format.MessageToDict(order.payload) if order.payload else None
1015-
),
1016-
tag=order.tag if order.tag else None,
1017-
)
1018-
except Exception as e:
1019-
_logger.error("Error converting order `%s`: %s", order, e)
1020-
raise
993+
return cls(
994+
delivery_area=DeliveryArea.from_pb(order.delivery_area),
995+
delivery_period=DeliveryPeriod.from_pb(order.delivery_period),
996+
type=OrderType.from_pb(order.type),
997+
side=MarketSide.from_pb(order.side),
998+
price=Price.from_pb(order.price),
999+
quantity=Power.from_pb(order.quantity),
1000+
stop_price=(
1001+
Price.from_pb(order.stop_price)
1002+
if order.HasField("stop_price")
1003+
else None
1004+
),
1005+
peak_price_delta=(
1006+
Price.from_pb(order.peak_price_delta)
1007+
if order.HasField("peak_price_delta")
1008+
else None
1009+
),
1010+
display_quantity=(
1011+
Power.from_pb(order.display_quantity)
1012+
if order.HasField("display_quantity")
1013+
else None
1014+
),
1015+
execution_option=(
1016+
OrderExecutionOption.from_pb(order.execution_option)
1017+
if order.HasField("execution_option")
1018+
else None
1019+
),
1020+
valid_until=(
1021+
order.valid_until.ToDatetime(tzinfo=timezone.utc)
1022+
if order.HasField("valid_until")
1023+
else None
1024+
),
1025+
payload=json_format.MessageToDict(order.payload) if order.payload else None,
1026+
tag=order.tag if order.tag else None,
1027+
)
10211028

10221029
def to_pb(self) -> electricity_trading_pb2.Order:
10231030
"""

0 commit comments

Comments
 (0)