Skip to content

Commit b013105

Browse files
committed
Validate Interval start/end
The start can't be bigger than the end and we raise a `ValueError` in that case. Signed-off-by: Leandro Lucarella <[email protected]>
1 parent 17183cb commit b013105

File tree

2 files changed

+32
-0
lines changed

2 files changed

+32
-0
lines changed

src/frequenz/core/collections.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,11 @@ class Interval(Generic[LessThanComparableOrNoneT]):
2727
The `start` and `end` are inclusive, meaning that the `start` and `end` limites are
2828
included in the range when checking if a value is contained by the interval.
2929
30+
If the `start` or `end` is `None`, it means that the interval is unbounded in that
31+
direction.
32+
33+
If `start` is bigger than `end`, a `ValueError` is raised.
34+
3035
The type stored in the interval must be comparable, meaning that it must implement
3136
the `__lt__` method to be able to compare values.
3237
"""
@@ -37,6 +42,17 @@ class Interval(Generic[LessThanComparableOrNoneT]):
3742
end: LessThanComparableOrNoneT
3843
"""The end of the interval."""
3944

45+
def __post_init__(self) -> None:
46+
"""Check if the start is less than or equal to the end."""
47+
if self.start is None or self.end is None:
48+
return
49+
start = cast(LessThanComparable, self.start)
50+
end = cast(LessThanComparable, self.end)
51+
if start > end:
52+
raise ValueError(
53+
f"The start ({self.start}) can't be bigger than end ({self.end})"
54+
)
55+
4056
def __contains__(self, item: LessThanComparableOrNoneT) -> bool:
4157
"""
4258
Check if the value is within the range of the container.

tests/test_collections.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,22 @@ def __repr__(self) -> str:
3333
return str(self.value)
3434

3535

36+
@pytest.mark.parametrize(
37+
"start, end",
38+
[
39+
(10.0, -100.0),
40+
(CustomComparable(10), CustomComparable(-100)),
41+
],
42+
)
43+
def test_invalid_range(start: LessThanComparable, end: LessThanComparable) -> None:
44+
"""Test if the interval has an invalid range."""
45+
with pytest.raises(
46+
ValueError,
47+
match=rf"The start \({start}\) can't be bigger than end \({end}\)",
48+
):
49+
Interval(start, end)
50+
51+
3652
@pytest.mark.parametrize(
3753
"start, end, within, at_start, at_end, before_start, after_end",
3854
[

0 commit comments

Comments
 (0)