diff --git a/jupyter_server/base/call_context.py b/jupyter_server/base/call_context.py index 4e80be8a7d..13f6a762ed 100644 --- a/jupyter_server/base/call_context.py +++ b/jupyter_server/base/call_context.py @@ -41,7 +41,6 @@ def get(cls, name: str) -> Any: The value associated with the named variable for this call context """ name_value_map = CallContext._get_map() - if name in name_value_map: return name_value_map[name] return None # TODO: should this raise `LookupError` (or a custom error derived from said) @@ -61,8 +60,9 @@ def set(cls, name: str, value: Any) -> None: ------- None """ - name_value_map = CallContext._get_map() + name_value_map = CallContext._get_map().copy() name_value_map[name] = value + CallContext._name_value_map.set(name_value_map) @classmethod def context_variable_names(cls) -> list[str]: diff --git a/tests/base/test_call_context.py b/tests/base/test_call_context.py index 1c12338d61..f3e48522f5 100644 --- a/tests/base/test_call_context.py +++ b/tests/base/test_call_context.py @@ -1,4 +1,5 @@ import asyncio +from contextvars import copy_context from jupyter_server import CallContext from jupyter_server.auth.utils import get_anonymous_username @@ -107,3 +108,39 @@ async def context2(): # Assert that THIS context doesn't have any variables defined. names = CallContext.context_variable_names() assert len(names) == 0 + + +async def test_callcontext_with_copy_context_run(): + """ + Test scenario: + - The upper layer uses copy_context().run() + - Multiple contexts concurrently modify CallContext + - Verify that no context pollution occurs + """ + + async def context_task(name, value, delay): + """Coroutine task that modifies CallContext and validates its own values""" + await asyncio.sleep(delay) + CallContext.set(name, value) + # Sleep again to simulate interleaving execution + await asyncio.sleep(0.1) + assert CallContext.get(name) == value, f"{name} was polluted" + # Ensure only the variable written by this context exists + keys = CallContext.context_variable_names() + assert name in keys + assert len(keys) == 1 + + # Initialize a variable in the main context + CallContext.set("foo", "bar3") + + # Create two independent copy_context instances + ctx1 = copy_context() + ctx2 = copy_context() + + # Run coroutines in their respective contexts + fut1 = asyncio.create_task(ctx1.run(lambda: context_task("foo", "bar1", 0.0))) + fut2 = asyncio.create_task(ctx2.run(lambda: context_task("foo", "bar2", 0.05))) + await asyncio.gather(fut1, fut2) + + # The main context should remain unaffected (still is empty) + assert CallContext.get("foo") == "bar3"