Skip to content

Commit c085a07

Browse files
feat: Add type annotations and refactor test utilities (#133)
Co-authored-by: Paillat <[email protected]>
1 parent 6ea3e05 commit c085a07

File tree

4 files changed

+93
-44
lines changed

4 files changed

+93
-44
lines changed

tests/test_find_util.py

Lines changed: 61 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -22,69 +22,109 @@
2222
DEALINGS IN THE SOFTWARE.
2323
"""
2424

25+
from __future__ import annotations
26+
27+
from collections.abc import Callable, Iterable, Iterator
28+
from typing import Literal, TypeVar
29+
from typing_extensions import TypeIs
30+
2531
import pytest
2632
from discord.utils import find
2733

34+
T = TypeVar("T")
35+
2836

29-
def is_even(x):
37+
def is_even(x: int) -> bool:
3038
return x % 2 == 0
3139

3240

41+
def always_true(_: object) -> bool:
42+
return True
43+
44+
45+
def greater_than_3(x: int) -> bool:
46+
return x > 3
47+
48+
49+
def equals_1(x: int) -> TypeIs[Literal[1]]:
50+
return x == 1
51+
52+
53+
def equals_2(x: int) -> TypeIs[Literal[2]]:
54+
return x == 2
55+
56+
57+
def equals_b(c: str) -> TypeIs[Literal["b"]]:
58+
return c == "b"
59+
60+
61+
def equals_30(x: int) -> TypeIs[Literal[30]]:
62+
return x == 30
63+
64+
65+
def is_none_pred(x: object) -> TypeIs[Literal[None]]:
66+
return x is None
67+
68+
3369
@pytest.mark.parametrize(
3470
("seq", "predicate", "expected"),
3571
[
36-
([], lambda x: True, None),
37-
([1, 2, 3], lambda x: x > 3, None),
38-
([1, 2, 3], lambda x: x == 1, 1),
39-
([1, 2, 3], lambda x: x == 2, 2),
40-
("abc", lambda c: c == "b", "b"),
41-
((10, 20, 30), lambda x: x == 30, 30),
42-
([None, False, 0], lambda x: x is None, None),
72+
([], always_true, None),
73+
([1, 2, 3], greater_than_3, None),
74+
([1, 2, 3], equals_1, 1),
75+
([1, 2, 3], equals_2, 2),
76+
("abc", equals_b, "b"),
77+
((10, 20, 30), equals_30, 30),
78+
([None, False, 0], is_none_pred, None),
4379
([1, 2, 3, 4], is_even, 2),
4480
],
4581
)
46-
def test_find_basic_parametrized(seq, predicate, expected):
82+
def test_find_basic_parametrized(
83+
seq: Iterable[T],
84+
predicate: Callable[[T], object],
85+
expected: T | None,
86+
) -> None:
4787
result = find(predicate, seq)
4888
if expected is None:
4989
assert result is None
5090
else:
5191
assert result == expected
5292

5393

54-
def test_find_with_truthy_non_boolean_predicate():
55-
seq = [2, 4, 5, 6]
94+
def test_find_with_truthy_non_boolean_predicate() -> None:
95+
seq: list[int] = [2, 4, 5, 6]
5696
result = find(lambda x: x % 2, seq)
5797
assert result == 5
5898

5999

60-
def test_find_on_generator_and_stop_early():
61-
def bad_gen():
100+
def test_find_on_generator_and_stop_early() -> None:
101+
def bad_gen() -> Iterator[str]:
62102
yield "first"
63103
raise RuntimeError("should not be reached")
64104

65105
assert find(lambda x: x == "first", bad_gen()) == "first"
66106

67107

68-
def test_find_does_not_evaluate_rest():
69-
calls = []
108+
def test_find_does_not_evaluate_rest() -> None:
109+
calls: list[str] = []
70110

71-
def predicate(x):
111+
def predicate(x: str) -> bool:
72112
calls.append(x)
73113
return x == "stop"
74114

75-
seq = ["go", "stop", "later"]
115+
seq: list[str] = ["go", "stop", "later"]
76116
result = find(predicate, seq)
77117
assert result == "stop"
78118
assert calls == ["go", "stop"]
79119

80120

81-
def test_find_with_set_returns_first_iterated_element():
82-
data = {"a", "b", "c"}
121+
def test_find_with_set_returns_first_iterated_element() -> None:
122+
data: set[str] = {"a", "b", "c"}
83123
result = find(lambda x: x in data, data)
84124
assert result in data
85125

86126

87-
def test_find_none_predicate():
88-
seq = [42, 43, 44]
127+
def test_find_none_predicate() -> None:
128+
seq: list[int] = [42, 43, 44]
89129
result = find(lambda x: True, seq)
90130
assert result == 42

tests/test_format_dt.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,21 @@
2525
import datetime
2626
import random
2727
import pytest
28-
from discord.utils import format_dt
28+
from discord.utils.public import format_dt, TimestampStyle
2929

3030
# Fix seed so that time tests are reproducible
3131
random.seed(42)
3232

33-
ALL_STYLES = ["t", "T", "d", "D", "f", "F", "R", None]
33+
ALL_STYLES = [
34+
"t",
35+
"T",
36+
"d",
37+
"D",
38+
"f",
39+
"F",
40+
"R",
41+
None,
42+
]
3443

3544
DATETIME_CASES = [
3645
(datetime.datetime(1970, 1, 1, 0, 0, 0, tzinfo=datetime.timezone.utc), 0),
@@ -41,7 +50,7 @@
4150
]
4251

4352

44-
def random_time():
53+
def random_time() -> datetime.time:
4554
return datetime.time(
4655
random.randint(0, 23),
4756
random.randint(0, 59),
@@ -51,7 +60,11 @@ def random_time():
5160

5261
@pytest.mark.parametrize(("dt", "expected_ts"), DATETIME_CASES)
5362
@pytest.mark.parametrize("style", ALL_STYLES)
54-
def test_format_dt_formats_datetime(dt, expected_ts, style):
63+
def test_format_dt_formats_datetime(
64+
dt: datetime.datetime,
65+
expected_ts: int,
66+
style: TimestampStyle | None,
67+
) -> None:
5568
if style is None:
5669
expected = f"<t:{expected_ts}>"
5770
else:
@@ -61,7 +74,9 @@ def test_format_dt_formats_datetime(dt, expected_ts, style):
6174

6275

6376
@pytest.mark.parametrize("style", ALL_STYLES)
64-
def test_format_dt_formats_time_equivalence(style):
77+
def test_format_dt_formats_time_equivalence(
78+
style: TimestampStyle | None,
79+
) -> None:
6580
tm = random_time()
6681
today = datetime.datetime.now().date()
6782
result_time = format_dt(tm, style=style)

tests/test_markdown_utils.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,20 +23,10 @@
2323
"""
2424

2525
from discord.utils import (
26-
oauth_url,
27-
snowflake_time,
28-
find,
29-
get_or_fetch,
30-
utcnow,
31-
remove_markdown,
32-
escape_markdown,
3326
escape_mentions,
3427
raw_mentions,
3528
raw_channel_mentions,
3629
raw_role_mentions,
37-
format_dt,
38-
generate_snowflake,
39-
basic_autocomplete,
4030
)
4131

4232

tests/test_snowflake_datetime.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,11 @@
2525
import datetime
2626
import pytest
2727

28-
from discord.utils import generate_snowflake, snowflake_time, DISCORD_EPOCH
28+
from discord.utils import (
29+
DISCORD_EPOCH,
30+
generate_snowflake,
31+
snowflake_time,
32+
)
2933

3034
UTC = datetime.timezone.utc
3135

@@ -39,40 +43,40 @@
3943

4044

4145
@pytest.mark.parametrize(("dt", "expected_ms"), DATETIME_CASES)
42-
def test_generate_snowflake_realistic(dt, expected_ms):
46+
def test_generate_snowflake_realistic(dt: datetime.datetime, expected_ms: int) -> None:
4347
sf = generate_snowflake(dt, mode="realistic")
4448
assert (sf >> 22) == expected_ms
4549
assert (sf & ((1 << 22) - 1)) == 0x3FFFFF
4650

4751

4852
@pytest.mark.parametrize(("dt", "expected_ms"), DATETIME_CASES)
49-
def test_generate_snowflake_boundary_low(dt, expected_ms):
53+
def test_generate_snowflake_boundary_low(dt: datetime.datetime, expected_ms: int) -> None:
5054
sf = generate_snowflake(dt, mode="boundary", high=False)
5155
assert (sf >> 22) == expected_ms
5256
assert (sf & ((1 << 22) - 1)) == 0
5357

5458

5559
@pytest.mark.parametrize(("dt", "expected_ms"), DATETIME_CASES)
56-
def test_generate_snowflake_boundary_high(dt, expected_ms):
60+
def test_generate_snowflake_boundary_high(dt: datetime.datetime, expected_ms: int) -> None:
5761
sf = generate_snowflake(dt, mode="boundary", high=True)
5862
assert (sf >> 22) == expected_ms
5963
assert (sf & ((1 << 22) - 1)) == (2**22 - 1)
6064

6165

6266
@pytest.mark.parametrize(("dt", "expected_ms"), DATETIME_CASES)
63-
def test_snowflake_time_roundtrip_boundary(dt, expected_ms):
67+
def test_snowflake_time_roundtrip_boundary(dt: datetime.datetime, _expected_ms: int) -> None:
6468
sf_low = generate_snowflake(dt, mode="boundary", high=False)
6569
sf_high = generate_snowflake(dt, mode="boundary", high=True)
6670
assert snowflake_time(sf_low) == dt
6771
assert snowflake_time(sf_high) == dt
6872

6973

7074
@pytest.mark.parametrize(("dt", "expected_ms"), DATETIME_CASES)
71-
def test_snowflake_time_roundtrip_realistic(dt, expected_ms):
75+
def test_snowflake_time_roundtrip_realistic(dt: datetime.datetime, _expected_ms: int) -> None:
7276
sf = generate_snowflake(dt, mode="realistic")
7377
assert snowflake_time(sf) == dt
7478

7579

76-
def test_generate_snowflake_invalid_mode():
80+
def test_generate_snowflake_invalid_mode() -> None:
7781
with pytest.raises(ValueError, match="Invalid mode 'nope'. Must be 'realistic' or 'boundary'"):
78-
generate_snowflake(datetime.datetime.now(tz=UTC), mode="nope")
82+
generate_snowflake(datetime.datetime.now(tz=UTC), mode="nope") # pyright: ignore[reportArgumentType]

0 commit comments

Comments
 (0)