|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
| 3 | +import itertools |
| 4 | +import random |
3 | 5 | import time |
| 6 | +from collections.abc import Callable |
4 | 7 | from typing import Any |
| 8 | +from weakref import WeakSet |
5 | 9 |
|
6 | 10 | import anyio |
7 | 11 | import anyio.lowlevel |
|
12 | 16 | pytestmark = pytest.mark.anyio |
13 | 17 |
|
14 | 18 |
|
| 19 | +random_args = random.sample( |
| 20 | + list( |
| 21 | + itertools.combinations( |
| 22 | + random.sample( |
| 23 | + list( |
| 24 | + itertools.combinations(list(itertools.chain(range(5), "abcde")), 4) |
| 25 | + ), |
| 26 | + 20, |
| 27 | + ), |
| 28 | + 2, |
| 29 | + ) |
| 30 | + ), |
| 31 | + 10, |
| 32 | +) |
| 33 | + |
| 34 | + |
15 | 35 | def test_memoize(cache): |
16 | 36 | count = 10 |
17 | 37 |
|
@@ -189,3 +209,131 @@ async def worker(num: int) -> int: |
189 | 209 | assert state["num"] > 0 |
190 | 210 |
|
191 | 211 | worker.wait() |
| 212 | + |
| 213 | + |
| 214 | +######################################## |
| 215 | + |
| 216 | + |
| 217 | +@pytest.mark.parametrize("memoize", [memo.memoize, memo.memoize_stampede]) |
| 218 | +def test_memoize_attrs(cache, memoize: Callable[..., memo.MemoizedDecorator]): |
| 219 | + def func() -> None: ... |
| 220 | + async def async_func() -> None: |
| 221 | + await anyio.lowlevel.checkpoint() |
| 222 | + |
| 223 | + wrapped = memoize(cache)(func) |
| 224 | + async_wrapped = memoize(cache)(async_func) |
| 225 | + |
| 226 | + assert wrapped is not func |
| 227 | + assert wrapped.__wrapped__ is func |
| 228 | + assert async_wrapped is not async_func |
| 229 | + assert async_wrapped.__wrapped__ is async_func |
| 230 | + |
| 231 | + if memoize is not memo.memoize_stampede: |
| 232 | + assert isinstance(wrapped, memo.Memoized) |
| 233 | + assert isinstance(async_wrapped, memo.AsyncMemoized) |
| 234 | + return |
| 235 | + |
| 236 | + assert isinstance(wrapped, memo.MemoizedStampede) |
| 237 | + assert isinstance(async_wrapped, memo.AsyncMemoizedStampede) |
| 238 | + assert isinstance(wrapped.futures, WeakSet) |
| 239 | + assert isinstance(async_wrapped.futures, WeakSet) |
| 240 | + |
| 241 | + |
| 242 | +@pytest.mark.parametrize("memoize", [memo.memoize, memo.memoize_stampede]) |
| 243 | +@pytest.mark.parametrize("is_async", [False, True]) |
| 244 | +@pytest.mark.parametrize(("memo_include", "memo_exclude"), random_args) |
| 245 | +async def test_memoize_include_and_exclude_args( |
| 246 | + cache, |
| 247 | + memoize: Callable[..., memo.MemoizedDecorator], |
| 248 | + is_async: bool, |
| 249 | + memo_include: tuple[str | int, ...], |
| 250 | + memo_exclude: tuple[str | int, ...], |
| 251 | +): |
| 252 | + include, exclude = set(memo_include), set(memo_exclude) |
| 253 | + cache.stats(enable=True) |
| 254 | + |
| 255 | + if is_async: |
| 256 | + |
| 257 | + @memoize(cache, include=include, exclude=exclude) |
| 258 | + async def func(*args: int, **kwargs: int) -> Any: |
| 259 | + await anyio.lowlevel.checkpoint() |
| 260 | + return sum(args) + sum(kwargs.values()) |
| 261 | + |
| 262 | + else: |
| 263 | + |
| 264 | + @memoize(cache, include=include, exclude=exclude) |
| 265 | + def func(*args: int, **kwargs: int) -> Any: |
| 266 | + return sum(args) + sum(kwargs.values()) |
| 267 | + |
| 268 | + args_all = list(range(5)) |
| 269 | + args_mask = set(range(5)).difference(include).difference(exclude) |
| 270 | + kwargs_all = {key: num for num, key in enumerate("abcde")} |
| 271 | + kwargs_mask = set("abcde").difference(include).difference(exclude) |
| 272 | + |
| 273 | + value = func(*args_all, **kwargs_all) |
| 274 | + if is_async: |
| 275 | + value = await value |
| 276 | + |
| 277 | + for key in exclude: |
| 278 | + if isinstance(key, int): |
| 279 | + args_all[key] += 1 |
| 280 | + else: |
| 281 | + kwargs_all[key] += 1 |
| 282 | + for index in args_mask: |
| 283 | + args_all[index] += 1 |
| 284 | + for key in kwargs_mask: |
| 285 | + kwargs_all[key] += 1 |
| 286 | + |
| 287 | + hits1, misses1 = cache.stats() |
| 288 | + alter = func(*args_all, **kwargs_all) |
| 289 | + if is_async: |
| 290 | + alter = await alter |
| 291 | + |
| 292 | + assert value == alter |
| 293 | + |
| 294 | + hits2, misses2 = cache.stats() |
| 295 | + |
| 296 | + assert hits2 == (hits1 + 1) |
| 297 | + assert misses2 == misses1 |
| 298 | + |
| 299 | + if memoize is memo.memoize_stampede: |
| 300 | + func.wait() # pyright: ignore[reportAttributeAccessIssue] |
| 301 | + |
| 302 | + |
| 303 | +@pytest.mark.parametrize("memoize", [memo.memoize, memo.memoize_stampede]) |
| 304 | +@pytest.mark.parametrize("is_async", [False, True]) |
| 305 | +async def test_cache_key( |
| 306 | + cache, memoize: Callable[..., memo.MemoizedDecorator], is_async: bool |
| 307 | +): |
| 308 | + if is_async: |
| 309 | + |
| 310 | + @memoize(cache) |
| 311 | + async def func(*args: int, **kwargs: int) -> Any: |
| 312 | + await anyio.lowlevel.checkpoint() |
| 313 | + return sum(args) + sum(kwargs.values()) |
| 314 | + |
| 315 | + else: |
| 316 | + |
| 317 | + @memoize(cache) |
| 318 | + def func(*args: int, **kwargs: int) -> Any: |
| 319 | + return sum(args) + sum(kwargs.values()) |
| 320 | + |
| 321 | + args = (1, 2, 3) |
| 322 | + kwargs = {"a": 4, "b": 5} |
| 323 | + value = func(*args, **kwargs) |
| 324 | + if is_async: |
| 325 | + value = await value |
| 326 | + |
| 327 | + key = func.cache_key(*args, **kwargs) |
| 328 | + |
| 329 | + assert key in cache |
| 330 | + container = cache[key] |
| 331 | + assert not container.default |
| 332 | + |
| 333 | + container_value = container.value |
| 334 | + if memoize is memo.memoize_stampede: |
| 335 | + container_value = container_value[0] |
| 336 | + assert container_value == value |
| 337 | + |
| 338 | + if memoize is memo.memoize_stampede: |
| 339 | + func.wait() # pyright: ignore[reportAttributeAccessIssue] |
0 commit comments