Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion tests/__main__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Licensed under the EUPL-1.2 or later.
# You may obtain a copy of the licence in all the official languages of the
# European Union at https://joinup.ec.europa.eu/collection/eupl/eupl-text-eupl-12
# pylint: disable=too-many-lines
# pylint: disable=too-many-lines, unidiomatic-typecheck

"""The tests for the Stream."""

Expand All @@ -14,6 +14,8 @@
import traceback
from collections import Counter
from collections.abc import Callable
from decimal import Decimal
from fractions import Fraction
from functools import partial
from numbers import Number, Real
from operator import add
Expand Down Expand Up @@ -128,6 +130,7 @@ def assert_raises(exc: type[BaseException], fun: Callable[[], object]) -> None:
assert_raises(AssertionError, lambda: assert_raises(Exception, lambda: None))
assert_raises(TypeError, lambda: hash(Stream(...)))
assert_raises(TypeError, lambda: hash(Stream([0, 1])))
assert_raises(StreamEmptyError, Stream([]).avg)

assert_raises(ValueError, lambda: sliding_window([], -1))
assert_raises(ValueError, lambda: sliding_window((), 0))
Expand Down Expand Up @@ -225,6 +228,23 @@ def raise_exceptions(number: int) -> int:
== sum(range(20)) - 1 - 3 - 5
)

average = sum(data_to_average := range(10, 100, 7)) / len(data_to_average)
average_float: float = Stream(data_to_average).avg()
assert type(average_float) is float
assert average_float == average
average_float = Stream(data_to_average).map(int).avg()
assert type(average_float) is float
assert average_float == average
average_float = Stream(data_to_average).map(float).avg()
assert type(average_float) is float
assert average_float == average
average_decimal: Decimal = Stream(data_to_average).map(Decimal).avg()
assert type(average_decimal) is Decimal
assert average_decimal == average
average_fraction: Fraction = Stream(data_to_average).map(Fraction).avg()
assert type(average_fraction) is Fraction
assert average_fraction == average


errors: list[ValueError] = []
assert (
Expand Down
13 changes: 12 additions & 1 deletion typed_stream/_impl/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from os import PathLike
from typing import Generic, Protocol, TypeAlias, TypeGuard, TypeVar

from ._typing import override
from ._typing import Self, override

__all__ = (
"ClassWithCleanUp",
Expand All @@ -22,6 +22,7 @@
"PrettyRepr",
"StarCallable",
"SupportsAdd",
"SupportsAverage",
"SupportsComparison",
"SupportsGreaterThan",
"SupportsLessThan",
Expand Down Expand Up @@ -85,6 +86,16 @@ def __add__(self: T, other: T) -> T:
"""Add another instance of the same type to self."""


class SupportsAverage(Protocol[T_co]):
"""A type that supports calculating an average if in an Iterable."""

def __add__(self, other: Self) -> Self:
"""Add another instance of the same type to self."""

def __truediv__(self, other: int) -> T_co:
"""Divide by an int."""


class Closeable(abc.ABC):
"""Class that can be closed."""

Expand Down
25 changes: 25 additions & 0 deletions typed_stream/_impl/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from ._types import (
StarCallable,
SupportsAdd,
SupportsAverage,
SupportsComparison,
TypeGuardingCallable,
)
Expand Down Expand Up @@ -315,6 +316,30 @@ def all(self) -> bool:
"""
return self._finish(all(self._data), close_source=True)

def avg(self: Stream[SupportsAverage[K]]) -> K:
"""Calculate the average of the elements in self.

Raises StreamEmptyError if the stream is empty.

>>> Stream(data := range(1, 5)).avg()
2.5
>>> Stream(data).map(int).avg()
2.5
>>> Stream(data).map(float).avg()
2.5
>>> Stream(data).map(__import__("decimal").Decimal).avg()
Decimal('2.5')
>>> Stream(data).map(__import__("fractions").Fraction).avg()
Fraction(5, 2)
>>> Stream([]).avg()
Traceback (most recent call last):
...
typed_stream.exceptions.StreamEmptyError
"""
counter = itertools.count()

return self.peek(lambda _: next(counter)).sum() / next(counter)

@overload
def catch(
self,
Expand Down