Skip to content

Commit 4447207

Browse files
committed
add Stream.avg to calculate the average
1 parent cb6185f commit 4447207

File tree

3 files changed

+58
-2
lines changed

3 files changed

+58
-2
lines changed

tests/__main__.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Licensed under the EUPL-1.2 or later.
22
# You may obtain a copy of the licence in all the official languages of the
33
# European Union at https://joinup.ec.europa.eu/collection/eupl/eupl-text-eupl-12
4-
# pylint: disable=too-many-lines
4+
# pylint: disable=too-many-lines, unidiomatic-typecheck
55

66
"""The tests for the Stream."""
77

@@ -14,6 +14,8 @@
1414
import traceback
1515
from collections import Counter
1616
from collections.abc import Callable
17+
from decimal import Decimal
18+
from fractions import Fraction
1719
from functools import partial
1820
from numbers import Number, Real
1921
from operator import add
@@ -128,6 +130,7 @@ def assert_raises(exc: type[BaseException], fun: Callable[[], object]) -> None:
128130
assert_raises(AssertionError, lambda: assert_raises(Exception, lambda: None))
129131
assert_raises(TypeError, lambda: hash(Stream(...)))
130132
assert_raises(TypeError, lambda: hash(Stream([0, 1])))
133+
assert_raises(StreamEmptyError, Stream([]).avg)
131134

132135
assert_raises(ValueError, lambda: sliding_window([], -1))
133136
assert_raises(ValueError, lambda: sliding_window((), 0))
@@ -225,6 +228,23 @@ def raise_exceptions(number: int) -> int:
225228
== sum(range(20)) - 1 - 3 - 5
226229
)
227230

231+
average = sum(data_to_average := range(10, 100, 7)) / len(data_to_average)
232+
average_float: float = Stream(data_to_average).avg()
233+
assert type(average_float) is float
234+
assert average_float == average
235+
average_float = Stream(data_to_average).map(int).avg()
236+
assert type(average_float) is float
237+
assert average_float == average
238+
average_float = Stream(data_to_average).map(float).avg()
239+
assert type(average_float) is float
240+
assert average_float == average
241+
average_decimal: Decimal = Stream(data_to_average).map(Decimal).avg()
242+
assert type(average_decimal) is Decimal
243+
assert average_decimal == average
244+
average_fraction: Fraction = Stream(data_to_average).map(Fraction).avg()
245+
assert type(average_fraction) is Fraction
246+
assert average_fraction == average
247+
228248

229249
errors: list[ValueError] = []
230250
assert (

typed_stream/_impl/_types.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from os import PathLike
1313
from typing import Generic, Protocol, TypeAlias, TypeGuard, TypeVar
1414

15-
from ._typing import override
15+
from ._typing import Self, override
1616

1717
__all__ = (
1818
"ClassWithCleanUp",
@@ -22,6 +22,7 @@
2222
"PrettyRepr",
2323
"StarCallable",
2424
"SupportsAdd",
25+
"SupportsAverage",
2526
"SupportsComparison",
2627
"SupportsGreaterThan",
2728
"SupportsLessThan",
@@ -85,6 +86,16 @@ def __add__(self: T, other: T) -> T:
8586
"""Add another instance of the same type to self."""
8687

8788

89+
class SupportsAverage(Protocol[T_co]):
90+
"""A type that supports calculating an average if in an Iterable."""
91+
92+
def __add__(self, other: Self) -> Self:
93+
"""Add another instance of the same type to self."""
94+
95+
def __truediv__(self, other: int) -> T_co:
96+
"""Divide by an int."""
97+
98+
8899
class Closeable(abc.ABC):
89100
"""Class that can be closed."""
90101

typed_stream/_impl/stream.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from ._types import (
3131
StarCallable,
3232
SupportsAdd,
33+
SupportsAverage,
3334
SupportsComparison,
3435
TypeGuardingCallable,
3536
)
@@ -315,6 +316,30 @@ def all(self) -> bool:
315316
"""
316317
return self._finish(all(self._data), close_source=True)
317318

319+
def avg(self: Stream[SupportsAverage[K]]) -> K:
320+
"""Calculate the average of the elements in self.
321+
322+
Raises StreamEmptyError if the stream is empty.
323+
324+
>>> Stream(data := range(1, 5)).avg()
325+
2.5
326+
>>> Stream(data).map(int).avg()
327+
2.5
328+
>>> Stream(data).map(float).avg()
329+
2.5
330+
>>> Stream(data).map(__import__("decimal").Decimal).avg()
331+
Decimal('2.5')
332+
>>> Stream(data).map(__import__("fractions").Fraction).avg()
333+
Fraction(5, 2)
334+
>>> Stream([]).avg()
335+
Traceback (most recent call last):
336+
...
337+
typed_stream.exceptions.StreamEmptyError
338+
"""
339+
counter = itertools.count()
340+
341+
return self.peek(lambda _: next(counter)).sum() / next(counter)
342+
318343
@overload
319344
def catch(
320345
self,

0 commit comments

Comments
 (0)