|
1 | 1 | # Licensed under the EUPL-1.2 or later. |
2 | 2 | # You may obtain a copy of the licence in all the official languages of the |
3 | 3 | # European Union at https://joinup.ec.europa.eu/collection/eupl/eupl-text-eupl-12 |
4 | | -# pylint: disable=too-many-lines |
| 4 | +# pylint: disable=too-many-lines, unidiomatic-typecheck |
5 | 5 |
|
6 | 6 | """The tests for the Stream.""" |
7 | 7 |
|
|
14 | 14 | import traceback |
15 | 15 | from collections import Counter |
16 | 16 | from collections.abc import Callable |
| 17 | +from decimal import Decimal |
| 18 | +from fractions import Fraction |
17 | 19 | from functools import partial |
18 | 20 | from numbers import Number, Real |
19 | 21 | from operator import add |
@@ -128,6 +130,7 @@ def assert_raises(exc: type[BaseException], fun: Callable[[], object]) -> None: |
128 | 130 | assert_raises(AssertionError, lambda: assert_raises(Exception, lambda: None)) |
129 | 131 | assert_raises(TypeError, lambda: hash(Stream(...))) |
130 | 132 | assert_raises(TypeError, lambda: hash(Stream([0, 1]))) |
| 133 | +assert_raises(StreamEmptyError, Stream([]).avg) |
131 | 134 |
|
132 | 135 | assert_raises(ValueError, lambda: sliding_window([], -1)) |
133 | 136 | assert_raises(ValueError, lambda: sliding_window((), 0)) |
@@ -225,6 +228,23 @@ def raise_exceptions(number: int) -> int: |
225 | 228 | == sum(range(20)) - 1 - 3 - 5 |
226 | 229 | ) |
227 | 230 |
|
| 231 | +average = sum(data_to_average := range(10, 100, 7)) / len(data_to_average) |
| 232 | +average_float: float = Stream(data_to_average).avg() |
| 233 | +assert type(average_float) is float |
| 234 | +assert average_float == average |
| 235 | +average_float = Stream(data_to_average).map(int).avg() |
| 236 | +assert type(average_float) is float |
| 237 | +assert average_float == average |
| 238 | +average_float = Stream(data_to_average).map(float).avg() |
| 239 | +assert type(average_float) is float |
| 240 | +assert average_float == average |
| 241 | +average_decimal: Decimal = Stream(data_to_average).map(Decimal).avg() |
| 242 | +assert type(average_decimal) is Decimal |
| 243 | +assert average_decimal == average |
| 244 | +average_fraction: Fraction = Stream(data_to_average).map(Fraction).avg() |
| 245 | +assert type(average_fraction) is Fraction |
| 246 | +assert average_fraction == average |
| 247 | + |
228 | 248 |
|
229 | 249 | errors: list[ValueError] = [] |
230 | 250 | assert ( |
|
0 commit comments