diff --git a/asgiref/local.py b/asgiref/local.py index 19b395f1..845313a8 100644 --- a/asgiref/local.py +++ b/asgiref/local.py @@ -2,38 +2,37 @@ import contextlib import contextvars import threading -from typing import Any, Union +from typing import Any, Dict, Union class _CVar: """Storage utility for Local.""" def __init__(self) -> None: - self._data: dict[str, contextvars.ContextVar[Any]] = {} + self._data: "contextvars.ContextVar[Dict[str, Any]]" = contextvars.ContextVar( + "asgiref.local" + ) - def __getattr__(self, key: str) -> Any: + def __getattr__(self, key): + storage_object = self._data.get({}) try: - var = self._data[key] + return storage_object[key] except KeyError: raise AttributeError(f"{self!r} object has no attribute {key!r}") - try: - return var.get() - except LookupError: - raise AttributeError(f"{self!r} object has no attribute {key!r}") - def __setattr__(self, key: str, value: Any) -> None: if key == "_data": return super().__setattr__(key, value) - var = self._data.get(key) - if var is None: - self._data[key] = var = contextvars.ContextVar(key) - var.set(value) + storage_object = self._data.get({}).copy() + storage_object[key] = value + self._data.set(storage_object) def __delattr__(self, key: str) -> None: - if key in self._data: - del self._data[key] + storage_object = self._data.get({}).copy() + if key in storage_object: + del storage_object[key] + self._data.set(storage_object) else: raise AttributeError(f"{self!r} object has no attribute {key!r}") diff --git a/tests/test_local.py b/tests/test_local.py index cdcbd280..77e57e89 100644 --- a/tests/test_local.py +++ b/tests/test_local.py @@ -375,3 +375,19 @@ async def _test() -> None: # Changes should not leak to the caller assert test_local.value == 0 + + +@pytest.mark.asyncio +async def test_deletion() -> None: + """Check visibility with asyncio tasks.""" + test_local = Local() + test_local.value = 123 + + async def _test() -> None: + # Local is inherited when changing task + assert test_local.value == 123 + del test_local.value + assert not hasattr(test_local, "value") + + await asyncio.create_task(_test()) + assert test_local.value == 123