Skip to content

Commit eb121d9

Browse files
committed
Add a generic enum_from_proto() function
This function can convert any `int` to any `Enum`, with optional validation or forward-compatibility. Signed-off-by: Leandro Lucarella <[email protected]>
1 parent 10d81cb commit eb121d9

File tree

2 files changed

+126
-0
lines changed

2 files changed

+126
-0
lines changed
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# License: MIT
2+
# Copyright © 2025 Frequenz Energy-as-a-Service GmbH
3+
4+
"""Conversion of protobuf int enums to Python enums."""
5+
6+
import enum
7+
from typing import Literal, TypeVar, overload
8+
9+
EnumT = TypeVar("EnumT", bound=enum.Enum)
10+
"""A type variable that is bound to an enum."""
11+
12+
13+
@overload
14+
def enum_from_proto(
15+
value: int, enum_type: type[EnumT], *, allow_invalid: Literal[False]
16+
) -> EnumT: ...
17+
18+
19+
@overload
20+
def enum_from_proto(
21+
value: int, enum_type: type[EnumT], *, allow_invalid: Literal[True] = True
22+
) -> EnumT | int: ...
23+
24+
25+
def enum_from_proto(
26+
value: int, enum_type: type[EnumT], *, allow_invalid: bool = True
27+
) -> EnumT | int:
28+
"""Convert a protobuf int enum value to a python enum.
29+
30+
Example:
31+
```python
32+
import enum
33+
34+
from proto import proto_pb2 # Just an example. pylint: disable=import-error
35+
36+
@enum.unique
37+
class SomeEnum(enum.Enum):
38+
# These values should match the protobuf enum values.
39+
UNSPECIFIED = 0
40+
SOME_VALUE = 1
41+
42+
enum_value = enum_from_proto(proto_pb2.SomeEnum.SOME_ENUM_SOME_VALUE, SomeEnum)
43+
# -> SomeEnum.SOME_VALUE
44+
45+
enum_value = enum_from_proto(42, SomeEnum)
46+
# -> 42
47+
48+
enum_value = enum_from_proto(
49+
proto_pb2.SomeEnum.SOME_ENUM_UNKNOWN_VALUE, SomeEnum, allow_invalid=False
50+
)
51+
# -> ValueError
52+
```
53+
54+
Args:
55+
value: The protobuf int enum value.
56+
enum_type: The python enum type to convert to.
57+
allow_invalid: If `True`, return the value as an `int` if the value is not
58+
a valid member of the enum (this allows for forward-compatibility with new
59+
enum values defined in the protocol but not added to the Python enum yet).
60+
If `False`, raise a `ValueError` if the value is not a valid member of the
61+
enum.
62+
63+
Returns:
64+
The resulting python enum value if the protobuf value is known, otherwise
65+
the input value converted to a plain `int`.
66+
67+
Raises:
68+
ValueError: If `allow_invalid` is `False` and the value is not a valid member
69+
of the enum.
70+
"""
71+
try:
72+
return enum_type(value)
73+
except ValueError:
74+
if allow_invalid:
75+
return value
76+
raise

tests/test_enum_proto.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# License: MIT
2+
# Copyright © 2025 Frequenz Energy-as-a-Service GmbH
3+
4+
"""Tests for enum_from_proto utility."""
5+
6+
import enum
7+
8+
import pytest
9+
10+
from frequenz.client.common.enum_proto import enum_from_proto
11+
12+
13+
class _TestEnum(enum.Enum):
14+
"""A test enum for enum_from_proto tests."""
15+
16+
ZERO = 0
17+
ONE = 1
18+
TWO = 2
19+
20+
21+
@pytest.mark.parametrize("enum_member", _TestEnum)
22+
def test_valid_allow_invalid(enum_member: _TestEnum) -> None:
23+
"""Test conversion of valid enum values."""
24+
assert enum_from_proto(enum_member.value, _TestEnum) == enum_member
25+
assert (
26+
enum_from_proto(enum_member.value, _TestEnum, allow_invalid=True) == enum_member
27+
)
28+
29+
30+
@pytest.mark.parametrize("value", [42, -1])
31+
def test_invalid_allow_invalid(value: int) -> None:
32+
"""Test unknown values with allow_invalid=True (default)."""
33+
assert enum_from_proto(value, _TestEnum) == value
34+
assert enum_from_proto(value, _TestEnum, allow_invalid=True) == value
35+
36+
37+
@pytest.mark.parametrize("enum_member", _TestEnum)
38+
def test_valid_disallow_invalid(enum_member: _TestEnum) -> None:
39+
"""Test unknown values with allow_invalid=False (should raise ValueError)."""
40+
assert (
41+
enum_from_proto(enum_member.value, _TestEnum, allow_invalid=False)
42+
== enum_member
43+
)
44+
45+
46+
@pytest.mark.parametrize("value", [42, -1])
47+
def test_invalid_disallow(value: int) -> None:
48+
"""Test unknown values with allow_invalid=False (should raise ValueError)."""
49+
with pytest.raises(ValueError, match=rf"^{value} is not a valid _TestEnum$"):
50+
enum_from_proto(value, _TestEnum, allow_invalid=False)

0 commit comments

Comments
 (0)