Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions jupyter_server/base/call_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ def get(cls, name: str) -> Any:
The value associated with the named variable for this call context
"""
name_value_map = CallContext._get_map()

if name in name_value_map:
return name_value_map[name]
return None # TODO: should this raise `LookupError` (or a custom error derived from said)
Expand All @@ -61,8 +60,9 @@ def set(cls, name: str, value: Any) -> None:
-------
None
"""
name_value_map = CallContext._get_map()
name_value_map = CallContext._get_map().copy()
name_value_map[name] = value
CallContext._name_value_map.set(name_value_map)

@classmethod
def context_variable_names(cls) -> list[str]:
Expand Down
37 changes: 37 additions & 0 deletions tests/base/test_call_context.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
from contextvars import copy_context

from jupyter_server import CallContext
from jupyter_server.auth.utils import get_anonymous_username
Expand Down Expand Up @@ -107,3 +108,39 @@ async def context2():
# Assert that THIS context doesn't have any variables defined.
names = CallContext.context_variable_names()
assert len(names) == 0


async def test_callcontext_with_copy_context_run():
"""
Test scenario:
- The upper layer uses copy_context().run()
- Multiple contexts concurrently modify CallContext
- Verify that no context pollution occurs
"""

async def context_task(name, value, delay):
"""Coroutine task that modifies CallContext and validates its own values"""
await asyncio.sleep(delay)
CallContext.set(name, value)
# Sleep again to simulate interleaving execution
await asyncio.sleep(0.1)
assert CallContext.get(name) == value, f"{name} was polluted"
# Ensure only the variable written by this context exists
keys = CallContext.context_variable_names()
assert name in keys
assert len(keys) == 1

# Initialize a variable in the main context
CallContext.set("foo", "bar3")

# Create two independent copy_context instances
ctx1 = copy_context()
ctx2 = copy_context()

# Run coroutines in their respective contexts
fut1 = asyncio.create_task(ctx1.run(lambda: context_task("foo", "bar1", 0.0)))
fut2 = asyncio.create_task(ctx2.run(lambda: context_task("foo", "bar2", 0.05)))
await asyncio.gather(fut1, fut2)

# The main context should remain unaffected (still is empty)
assert CallContext.get("foo") == "bar3"
Loading