Skip to content

Commit 987ebdd

Browse files
dualcchengcong1
andauthored
fix context pollution (#1561)
Co-authored-by: chengcong1 <[email protected]>
1 parent 8f99062 commit 987ebdd

File tree

2 files changed

+39
-2
lines changed

2 files changed

+39
-2
lines changed

jupyter_server/base/call_context.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ def get(cls, name: str) -> Any:
4141
The value associated with the named variable for this call context
4242
"""
4343
name_value_map = CallContext._get_map()
44-
4544
if name in name_value_map:
4645
return name_value_map[name]
4746
return None # TODO: should this raise `LookupError` (or a custom error derived from said)
@@ -61,8 +60,9 @@ def set(cls, name: str, value: Any) -> None:
6160
-------
6261
None
6362
"""
64-
name_value_map = CallContext._get_map()
63+
name_value_map = CallContext._get_map().copy()
6564
name_value_map[name] = value
65+
CallContext._name_value_map.set(name_value_map)
6666

6767
@classmethod
6868
def context_variable_names(cls) -> list[str]:

tests/base/test_call_context.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
from contextvars import copy_context
23

34
from jupyter_server import CallContext
45
from jupyter_server.auth.utils import get_anonymous_username
@@ -107,3 +108,39 @@ async def context2():
107108
# Assert that THIS context doesn't have any variables defined.
108109
names = CallContext.context_variable_names()
109110
assert len(names) == 0
111+
112+
113+
async def test_callcontext_with_copy_context_run():
114+
"""
115+
Test scenario:
116+
- The upper layer uses copy_context().run()
117+
- Multiple contexts concurrently modify CallContext
118+
- Verify that no context pollution occurs
119+
"""
120+
121+
async def context_task(name, value, delay):
122+
"""Coroutine task that modifies CallContext and validates its own values"""
123+
await asyncio.sleep(delay)
124+
CallContext.set(name, value)
125+
# Sleep again to simulate interleaving execution
126+
await asyncio.sleep(0.1)
127+
assert CallContext.get(name) == value, f"{name} was polluted"
128+
# Ensure only the variable written by this context exists
129+
keys = CallContext.context_variable_names()
130+
assert name in keys
131+
assert len(keys) == 1
132+
133+
# Initialize a variable in the main context
134+
CallContext.set("foo", "bar3")
135+
136+
# Create two independent copy_context instances
137+
ctx1 = copy_context()
138+
ctx2 = copy_context()
139+
140+
# Run coroutines in their respective contexts
141+
fut1 = asyncio.create_task(ctx1.run(lambda: context_task("foo", "bar1", 0.0)))
142+
fut2 = asyncio.create_task(ctx2.run(lambda: context_task("foo", "bar2", 0.05)))
143+
await asyncio.gather(fut1, fut2)
144+
145+
# The main context should remain unaffected (still is empty)
146+
assert CallContext.get("foo") == "bar3"

0 commit comments

Comments
 (0)