Skip to content

Commit 5c4af69

Browse files
committed
tests: memoize include, exclude, etc
1 parent 1b40c9a commit 5c4af69

File tree

2 files changed

+150
-0
lines changed

2 files changed

+150
-0
lines changed

ruff.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,8 @@ ignore = [
108108
"D103",
109109
"D102",
110110
"PLR2004",
111+
"C",
112+
"FBT001",
111113
]
112114

113115
[format]

src/tests/test_memoize.py

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
from __future__ import annotations
22

3+
import itertools
4+
import random
35
import time
6+
from collections.abc import Callable
47
from typing import Any
8+
from weakref import WeakSet
59

610
import anyio
711
import anyio.lowlevel
@@ -12,6 +16,22 @@
1216
pytestmark = pytest.mark.anyio
1317

1418

19+
random_args = random.sample(
20+
list(
21+
itertools.combinations(
22+
random.sample(
23+
list(
24+
itertools.combinations(list(itertools.chain(range(5), "abcde")), 4)
25+
),
26+
20,
27+
),
28+
2,
29+
)
30+
),
31+
10,
32+
)
33+
34+
1535
def test_memoize(cache):
1636
count = 10
1737

@@ -189,3 +209,131 @@ async def worker(num: int) -> int:
189209
assert state["num"] > 0
190210

191211
worker.wait()
212+
213+
214+
########################################
215+
216+
217+
@pytest.mark.parametrize("memoize", [memo.memoize, memo.memoize_stampede])
218+
def test_memoize_attrs(cache, memoize: Callable[..., memo.MemoizedDecorator]):
219+
def func() -> None: ...
220+
async def async_func() -> None:
221+
await anyio.lowlevel.checkpoint()
222+
223+
wrapped = memoize(cache)(func)
224+
async_wrapped = memoize(cache)(async_func)
225+
226+
assert wrapped is not func
227+
assert wrapped.__wrapped__ is func
228+
assert async_wrapped is not async_func
229+
assert async_wrapped.__wrapped__ is async_func
230+
231+
if memoize is not memo.memoize_stampede:
232+
assert isinstance(wrapped, memo.Memoized)
233+
assert isinstance(async_wrapped, memo.AsyncMemoized)
234+
return
235+
236+
assert isinstance(wrapped, memo.MemoizedStampede)
237+
assert isinstance(async_wrapped, memo.AsyncMemoizedStampede)
238+
assert isinstance(wrapped.futures, WeakSet)
239+
assert isinstance(async_wrapped.futures, WeakSet)
240+
241+
242+
@pytest.mark.parametrize("memoize", [memo.memoize, memo.memoize_stampede])
243+
@pytest.mark.parametrize("is_async", [False, True])
244+
@pytest.mark.parametrize(("memo_include", "memo_exclude"), random_args)
245+
async def test_memoize_include_and_exclude_args(
246+
cache,
247+
memoize: Callable[..., memo.MemoizedDecorator],
248+
is_async: bool,
249+
memo_include: tuple[str | int, ...],
250+
memo_exclude: tuple[str | int, ...],
251+
):
252+
include, exclude = set(memo_include), set(memo_exclude)
253+
cache.stats(enable=True)
254+
255+
if is_async:
256+
257+
@memoize(cache, include=include, exclude=exclude)
258+
async def func(*args: int, **kwargs: int) -> Any:
259+
await anyio.lowlevel.checkpoint()
260+
return sum(args) + sum(kwargs.values())
261+
262+
else:
263+
264+
@memoize(cache, include=include, exclude=exclude)
265+
def func(*args: int, **kwargs: int) -> Any:
266+
return sum(args) + sum(kwargs.values())
267+
268+
args_all = list(range(5))
269+
args_mask = set(range(5)).difference(include).difference(exclude)
270+
kwargs_all = {key: num for num, key in enumerate("abcde")}
271+
kwargs_mask = set("abcde").difference(include).difference(exclude)
272+
273+
value = func(*args_all, **kwargs_all)
274+
if is_async:
275+
value = await value
276+
277+
for key in exclude:
278+
if isinstance(key, int):
279+
args_all[key] += 1
280+
else:
281+
kwargs_all[key] += 1
282+
for index in args_mask:
283+
args_all[index] += 1
284+
for key in kwargs_mask:
285+
kwargs_all[key] += 1
286+
287+
hits1, misses1 = cache.stats()
288+
alter = func(*args_all, **kwargs_all)
289+
if is_async:
290+
alter = await alter
291+
292+
assert value == alter
293+
294+
hits2, misses2 = cache.stats()
295+
296+
assert hits2 == (hits1 + 1)
297+
assert misses2 == misses1
298+
299+
if memoize is memo.memoize_stampede:
300+
func.wait() # pyright: ignore[reportAttributeAccessIssue]
301+
302+
303+
@pytest.mark.parametrize("memoize", [memo.memoize, memo.memoize_stampede])
304+
@pytest.mark.parametrize("is_async", [False, True])
305+
async def test_cache_key(
306+
cache, memoize: Callable[..., memo.MemoizedDecorator], is_async: bool
307+
):
308+
if is_async:
309+
310+
@memoize(cache)
311+
async def func(*args: int, **kwargs: int) -> Any:
312+
await anyio.lowlevel.checkpoint()
313+
return sum(args) + sum(kwargs.values())
314+
315+
else:
316+
317+
@memoize(cache)
318+
def func(*args: int, **kwargs: int) -> Any:
319+
return sum(args) + sum(kwargs.values())
320+
321+
args = (1, 2, 3)
322+
kwargs = {"a": 4, "b": 5}
323+
value = func(*args, **kwargs)
324+
if is_async:
325+
value = await value
326+
327+
key = func.cache_key(*args, **kwargs)
328+
329+
assert key in cache
330+
container = cache[key]
331+
assert not container.default
332+
333+
container_value = container.value
334+
if memoize is memo.memoize_stampede:
335+
container_value = container_value[0]
336+
assert container_value == value
337+
338+
if memoize is memo.memoize_stampede:
339+
func.wait() # pyright: ignore[reportAttributeAccessIssue]

0 commit comments

Comments
 (0)