Skip to content

Commit 26e69e2

Browse files
authored
Add a new Interval class based on the SDK Bounds (#19)
This new class is based on the SDK `Bounds` class. It is renamed to `Interval` partly because this is a more generic repo and this is a more generic name, but also because we will add `Bounds` class in the future to represent multiple (inclusion) intervals, which is how the new microgrid API handle bounds. Fixes #11, #12, #13, #15.
2 parents 4587753 + b013105 commit 26e69e2

File tree

2 files changed

+237
-0
lines changed

2 files changed

+237
-0
lines changed

src/frequenz/core/collections.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
# License: MIT
2+
# Copyright © 2022 Frequenz Energy-as-a-Service GmbH
3+
4+
"""Data structures that contain collections of values or objects."""
5+
6+
from dataclasses import dataclass
7+
from typing import Generic, Protocol, Self, TypeVar, cast
8+
9+
10+
class LessThanComparable(Protocol):
11+
"""A protocol that requires the `__lt__` method to compare values."""
12+
13+
def __lt__(self, other: Self, /) -> bool:
14+
"""Return whether self is less than other."""
15+
16+
17+
LessThanComparableOrNoneT = TypeVar(
18+
"LessThanComparableOrNoneT", bound=LessThanComparable | None
19+
)
20+
"""Type variable for a value that a `LessThanComparable` or `None`."""
21+
22+
23+
@dataclass(frozen=True)
24+
class Interval(Generic[LessThanComparableOrNoneT]):
25+
"""An interval to test if a value is within its limits.
26+
27+
The `start` and `end` are inclusive, meaning that the `start` and `end` limites are
28+
included in the range when checking if a value is contained by the interval.
29+
30+
If the `start` or `end` is `None`, it means that the interval is unbounded in that
31+
direction.
32+
33+
If `start` is bigger than `end`, a `ValueError` is raised.
34+
35+
The type stored in the interval must be comparable, meaning that it must implement
36+
the `__lt__` method to be able to compare values.
37+
"""
38+
39+
start: LessThanComparableOrNoneT
40+
"""The start of the interval."""
41+
42+
end: LessThanComparableOrNoneT
43+
"""The end of the interval."""
44+
45+
def __post_init__(self) -> None:
46+
"""Check if the start is less than or equal to the end."""
47+
if self.start is None or self.end is None:
48+
return
49+
start = cast(LessThanComparable, self.start)
50+
end = cast(LessThanComparable, self.end)
51+
if start > end:
52+
raise ValueError(
53+
f"The start ({self.start}) can't be bigger than end ({self.end})"
54+
)
55+
56+
def __contains__(self, item: LessThanComparableOrNoneT) -> bool:
57+
"""
58+
Check if the value is within the range of the container.
59+
60+
Args:
61+
item: The value to check.
62+
63+
Returns:
64+
bool: True if value is within the range, otherwise False.
65+
"""
66+
if item is None:
67+
return False
68+
casted_item = cast(LessThanComparable, item)
69+
70+
if self.start is None and self.end is None:
71+
return True
72+
if self.start is None:
73+
start = cast(LessThanComparable, self.end)
74+
return not casted_item > start
75+
if self.end is None:
76+
return not self.start > item
77+
# mypy seems to get confused here, not being able to narrow start and end to
78+
# just LessThanComparable, complaining with:
79+
# error: Unsupported left operand type for <= (some union)
80+
# But we know if they are not None, they should be LessThanComparable, and
81+
# actually mypy is being able to figure it out in the lines above, just not in
82+
# this one, so it should be safe to cast.
83+
return not (
84+
casted_item < cast(LessThanComparable, self.start)
85+
or casted_item > cast(LessThanComparable, self.end)
86+
)

tests/test_collections.py

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
# License: MIT
2+
# Copyright © 2024 Frequenz Energy-as-a-Service GmbH
3+
4+
"""Tests for the collections module."""
5+
6+
7+
from typing import Self
8+
9+
import pytest
10+
11+
from frequenz.core.collections import Interval, LessThanComparable
12+
13+
14+
class CustomComparable:
15+
"""A custom comparable class."""
16+
17+
def __init__(self, value: int) -> None:
18+
"""Initialize this instance."""
19+
self.value = value
20+
21+
def __lt__(self, other: Self) -> bool:
22+
"""Return whether this instance is less than other."""
23+
return self.value < other.value
24+
25+
def __eq__(self, other: object) -> bool:
26+
"""Return whether this instance is equal to other."""
27+
if not isinstance(other, CustomComparable):
28+
return False
29+
return self.value == other.value
30+
31+
def __repr__(self) -> str:
32+
"""Return a string representation of this instance."""
33+
return str(self.value)
34+
35+
36+
@pytest.mark.parametrize(
37+
"start, end",
38+
[
39+
(10.0, -100.0),
40+
(CustomComparable(10), CustomComparable(-100)),
41+
],
42+
)
43+
def test_invalid_range(start: LessThanComparable, end: LessThanComparable) -> None:
44+
"""Test if the interval has an invalid range."""
45+
with pytest.raises(
46+
ValueError,
47+
match=rf"The start \({start}\) can't be bigger than end \({end}\)",
48+
):
49+
Interval(start, end)
50+
51+
52+
@pytest.mark.parametrize(
53+
"start, end, within, at_start, at_end, before_start, after_end",
54+
[
55+
(10.0, 100.0, 50.0, 10.0, 100.0, 9.0, 101.0),
56+
(
57+
CustomComparable(10),
58+
CustomComparable(100),
59+
CustomComparable(50),
60+
CustomComparable(10),
61+
CustomComparable(100),
62+
CustomComparable(9),
63+
CustomComparable(101),
64+
),
65+
],
66+
)
67+
def test_interval_contains( # pylint: disable=too-many-arguments
68+
start: LessThanComparable,
69+
end: LessThanComparable,
70+
within: LessThanComparable,
71+
at_start: LessThanComparable,
72+
at_end: LessThanComparable,
73+
before_start: LessThanComparable,
74+
after_end: LessThanComparable,
75+
) -> None:
76+
"""Test if a value is within the interval."""
77+
interval = Interval(start=start, end=end)
78+
assert within in interval # within
79+
assert at_start in interval # at start
80+
assert at_end in interval # at end
81+
assert before_start not in interval # before start
82+
assert after_end not in interval # after end
83+
84+
85+
@pytest.mark.parametrize(
86+
"end, within, at_end, after_end",
87+
[
88+
(100.0, 50.0, 100.0, 101.0),
89+
(
90+
CustomComparable(100),
91+
CustomComparable(50),
92+
CustomComparable(100),
93+
CustomComparable(101),
94+
),
95+
],
96+
)
97+
def test_interval_contains_no_start(
98+
end: LessThanComparable,
99+
within: LessThanComparable,
100+
at_end: LessThanComparable,
101+
after_end: LessThanComparable,
102+
) -> None:
103+
"""Test if a value is within the interval with no start."""
104+
interval_no_start = Interval(start=None, end=end)
105+
assert within in interval_no_start # within end
106+
assert at_end in interval_no_start # at end
107+
assert after_end not in interval_no_start # after end
108+
109+
110+
@pytest.mark.parametrize(
111+
"start, within, at_start, before_start",
112+
[
113+
(10.0, 50.0, 10.0, 9.0),
114+
(
115+
CustomComparable(10),
116+
CustomComparable(50),
117+
CustomComparable(10),
118+
CustomComparable(9),
119+
),
120+
],
121+
)
122+
def test_interval_contains_no_end(
123+
start: LessThanComparable,
124+
within: LessThanComparable,
125+
at_start: LessThanComparable,
126+
before_start: LessThanComparable,
127+
) -> None:
128+
"""Test if a value is within the interval with no end."""
129+
interval_no_end = Interval(start=start, end=None)
130+
assert within in interval_no_end # within start
131+
assert at_start in interval_no_end # at start
132+
assert before_start not in interval_no_end # before start
133+
134+
135+
@pytest.mark.parametrize(
136+
"value",
137+
[
138+
50.0,
139+
10.0,
140+
-10.0,
141+
CustomComparable(50),
142+
CustomComparable(10),
143+
CustomComparable(-10),
144+
],
145+
)
146+
def test_interval_contains_unbound(value: LessThanComparable) -> None:
147+
"""Test if a value is within the interval with no bounds."""
148+
interval_no_bounds: Interval[LessThanComparable | None] = Interval(
149+
start=None, end=None
150+
)
151+
assert value in interval_no_bounds # any value within bounds

0 commit comments

Comments
 (0)