Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 86 additions & 0 deletions src/frequenz/core/collections.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# License: MIT
# Copyright © 2022 Frequenz Energy-as-a-Service GmbH

"""Data structures that contain collections of values or objects."""

from dataclasses import dataclass
from typing import Generic, Protocol, Self, TypeVar, cast


class LessThanComparable(Protocol):
"""A protocol that requires the `__lt__` method to compare values."""

def __lt__(self, other: Self, /) -> bool:
"""Return whether self is less than other."""


LessThanComparableOrNoneT = TypeVar(
"LessThanComparableOrNoneT", bound=LessThanComparable | None
)
"""Type variable for a value that a `LessThanComparable` or `None`."""


@dataclass(frozen=True)
class Interval(Generic[LessThanComparableOrNoneT]):
"""An interval to test if a value is within its limits.

The `start` and `end` are inclusive, meaning that the `start` and `end` limites are
included in the range when checking if a value is contained by the interval.

If the `start` or `end` is `None`, it means that the interval is unbounded in that
direction.

If `start` is bigger than `end`, a `ValueError` is raised.

The type stored in the interval must be comparable, meaning that it must implement
the `__lt__` method to be able to compare values.
"""

start: LessThanComparableOrNoneT
"""The start of the interval."""

end: LessThanComparableOrNoneT
"""The end of the interval."""

def __post_init__(self) -> None:
"""Check if the start is less than or equal to the end."""
if self.start is None or self.end is None:
return
start = cast(LessThanComparable, self.start)
end = cast(LessThanComparable, self.end)
if start > end:
raise ValueError(
f"The start ({self.start}) can't be bigger than end ({self.end})"
)

def __contains__(self, item: LessThanComparableOrNoneT) -> bool:
"""
Check if the value is within the range of the container.

Args:
item: The value to check.

Returns:
bool: True if value is within the range, otherwise False.
"""
if item is None:
return False
casted_item = cast(LessThanComparable, item)

if self.start is None and self.end is None:
return True
if self.start is None:
start = cast(LessThanComparable, self.end)
return not casted_item > start
if self.end is None:
return not self.start > item
# mypy seems to get confused here, not being able to narrow start and end to
# just LessThanComparable, complaining with:
# error: Unsupported left operand type for <= (some union)
# But we know if they are not None, they should be LessThanComparable, and
# actually mypy is being able to figure it out in the lines above, just not in
# this one, so it should be safe to cast.
return not (
casted_item < cast(LessThanComparable, self.start)
or casted_item > cast(LessThanComparable, self.end)
)
151 changes: 151 additions & 0 deletions tests/test_collections.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
# License: MIT
# Copyright © 2024 Frequenz Energy-as-a-Service GmbH

"""Tests for the collections module."""


from typing import Self

import pytest

from frequenz.core.collections import Interval, LessThanComparable


class CustomComparable:
"""A custom comparable class."""

def __init__(self, value: int) -> None:
"""Initialize this instance."""
self.value = value

def __lt__(self, other: Self) -> bool:
"""Return whether this instance is less than other."""
return self.value < other.value

def __eq__(self, other: object) -> bool:
"""Return whether this instance is equal to other."""
if not isinstance(other, CustomComparable):
return False
return self.value == other.value

def __repr__(self) -> str:
"""Return a string representation of this instance."""
return str(self.value)


@pytest.mark.parametrize(
"start, end",
[
(10.0, -100.0),
(CustomComparable(10), CustomComparable(-100)),
],
)
def test_invalid_range(start: LessThanComparable, end: LessThanComparable) -> None:
"""Test if the interval has an invalid range."""
with pytest.raises(
ValueError,
match=rf"The start \({start}\) can't be bigger than end \({end}\)",
):
Interval(start, end)


@pytest.mark.parametrize(
"start, end, within, at_start, at_end, before_start, after_end",
[
(10.0, 100.0, 50.0, 10.0, 100.0, 9.0, 101.0),
(
CustomComparable(10),
CustomComparable(100),
CustomComparable(50),
CustomComparable(10),
CustomComparable(100),
CustomComparable(9),
CustomComparable(101),
),
],
)
def test_interval_contains( # pylint: disable=too-many-arguments
start: LessThanComparable,
end: LessThanComparable,
within: LessThanComparable,
at_start: LessThanComparable,
at_end: LessThanComparable,
before_start: LessThanComparable,
after_end: LessThanComparable,
) -> None:
"""Test if a value is within the interval."""
interval = Interval(start=start, end=end)
assert within in interval # within
assert at_start in interval # at start
assert at_end in interval # at end
assert before_start not in interval # before start
assert after_end not in interval # after end


@pytest.mark.parametrize(
"end, within, at_end, after_end",
[
(100.0, 50.0, 100.0, 101.0),
(
CustomComparable(100),
CustomComparable(50),
CustomComparable(100),
CustomComparable(101),
),
],
)
def test_interval_contains_no_start(
end: LessThanComparable,
within: LessThanComparable,
at_end: LessThanComparable,
after_end: LessThanComparable,
) -> None:
"""Test if a value is within the interval with no start."""
interval_no_start = Interval(start=None, end=end)
assert within in interval_no_start # within end
assert at_end in interval_no_start # at end
assert after_end not in interval_no_start # after end


@pytest.mark.parametrize(
"start, within, at_start, before_start",
[
(10.0, 50.0, 10.0, 9.0),
(
CustomComparable(10),
CustomComparable(50),
CustomComparable(10),
CustomComparable(9),
),
],
)
def test_interval_contains_no_end(
start: LessThanComparable,
within: LessThanComparable,
at_start: LessThanComparable,
before_start: LessThanComparable,
) -> None:
"""Test if a value is within the interval with no end."""
interval_no_end = Interval(start=start, end=None)
assert within in interval_no_end # within start
assert at_start in interval_no_end # at start
assert before_start not in interval_no_end # before start


@pytest.mark.parametrize(
"value",
[
50.0,
10.0,
-10.0,
CustomComparable(50),
CustomComparable(10),
CustomComparable(-10),
],
)
def test_interval_contains_unbound(value: LessThanComparable) -> None:
"""Test if a value is within the interval with no bounds."""
interval_no_bounds: Interval[LessThanComparable | None] = Interval(
start=None, end=None
)
assert value in interval_no_bounds # any value within bounds