diff --git a/tests/__main__.py b/tests/__main__.py index 43cdfb2..053099f 100644 --- a/tests/__main__.py +++ b/tests/__main__.py @@ -1,7 +1,7 @@ # Licensed under the EUPL-1.2 or later. # You may obtain a copy of the licence in all the official languages of the # European Union at https://joinup.ec.europa.eu/collection/eupl/eupl-text-eupl-12 -# pylint: disable=too-many-lines +# pylint: disable=too-many-lines, unidiomatic-typecheck """The tests for the Stream.""" @@ -14,6 +14,8 @@ import traceback from collections import Counter from collections.abc import Callable +from decimal import Decimal +from fractions import Fraction from functools import partial from numbers import Number, Real from operator import add @@ -128,6 +130,7 @@ def assert_raises(exc: type[BaseException], fun: Callable[[], object]) -> None: assert_raises(AssertionError, lambda: assert_raises(Exception, lambda: None)) assert_raises(TypeError, lambda: hash(Stream(...))) assert_raises(TypeError, lambda: hash(Stream([0, 1]))) +assert_raises(StreamEmptyError, Stream([]).avg) assert_raises(ValueError, lambda: sliding_window([], -1)) assert_raises(ValueError, lambda: sliding_window((), 0)) @@ -225,6 +228,23 @@ def raise_exceptions(number: int) -> int: == sum(range(20)) - 1 - 3 - 5 ) +average = sum(data_to_average := range(10, 100, 7)) / len(data_to_average) +average_float: float = Stream(data_to_average).avg() +assert type(average_float) is float +assert average_float == average +average_float = Stream(data_to_average).map(int).avg() +assert type(average_float) is float +assert average_float == average +average_float = Stream(data_to_average).map(float).avg() +assert type(average_float) is float +assert average_float == average +average_decimal: Decimal = Stream(data_to_average).map(Decimal).avg() +assert type(average_decimal) is Decimal +assert average_decimal == average +average_fraction: Fraction = Stream(data_to_average).map(Fraction).avg() +assert type(average_fraction) is Fraction +assert average_fraction == average + errors: list[ValueError] = [] assert ( diff --git a/typed_stream/_impl/_types.py b/typed_stream/_impl/_types.py index 4209909..85411d4 100644 --- a/typed_stream/_impl/_types.py +++ b/typed_stream/_impl/_types.py @@ -12,7 +12,7 @@ from os import PathLike from typing import Generic, Protocol, TypeAlias, TypeGuard, TypeVar -from ._typing import override +from ._typing import Self, override __all__ = ( "ClassWithCleanUp", @@ -22,6 +22,7 @@ "PrettyRepr", "StarCallable", "SupportsAdd", + "SupportsAverage", "SupportsComparison", "SupportsGreaterThan", "SupportsLessThan", @@ -85,6 +86,16 @@ def __add__(self: T, other: T) -> T: """Add another instance of the same type to self.""" +class SupportsAverage(Protocol[T_co]): + """A type that supports calculating an average if in an Iterable.""" + + def __add__(self, other: Self) -> Self: + """Add another instance of the same type to self.""" + + def __truediv__(self, other: int) -> T_co: + """Divide by an int.""" + + class Closeable(abc.ABC): """Class that can be closed.""" diff --git a/typed_stream/_impl/stream.py b/typed_stream/_impl/stream.py index 239cdc0..6b42f53 100644 --- a/typed_stream/_impl/stream.py +++ b/typed_stream/_impl/stream.py @@ -30,6 +30,7 @@ from ._types import ( StarCallable, SupportsAdd, + SupportsAverage, SupportsComparison, TypeGuardingCallable, ) @@ -315,6 +316,30 @@ def all(self) -> bool: """ return self._finish(all(self._data), close_source=True) + def avg(self: Stream[SupportsAverage[K]]) -> K: + """Calculate the average of the elements in self. + + Raises StreamEmptyError if the stream is empty. + + >>> Stream(data := range(1, 5)).avg() + 2.5 + >>> Stream(data).map(int).avg() + 2.5 + >>> Stream(data).map(float).avg() + 2.5 + >>> Stream(data).map(__import__("decimal").Decimal).avg() + Decimal('2.5') + >>> Stream(data).map(__import__("fractions").Fraction).avg() + Fraction(5, 2) + >>> Stream([]).avg() + Traceback (most recent call last): + ... + typed_stream.exceptions.StreamEmptyError + """ + counter = itertools.count() + + return self.peek(lambda _: next(counter)).sum() / next(counter) + @overload def catch( self,