Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions Lib/asyncio/__main__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import ast
import asyncio
import concurrent.futures
import contextvars
import inspect
import os
import site
Expand All @@ -22,6 +23,7 @@ def __init__(self, locals, loop):
self.compile.compiler.flags |= ast.PyCF_ALLOW_TOP_LEVEL_AWAIT

self.loop = loop
self.context = contextvars.copy_context()

def runcode(self, code):
global return_code
Expand Down Expand Up @@ -55,12 +57,12 @@ def callback():
return

try:
repl_future = self.loop.create_task(coro)
repl_future = self.loop.create_task(coro, context=self.context)
futures._chain_future(repl_future, future)
except BaseException as exc:
future.set_exception(exc)

loop.call_soon_threadsafe(callback)
loop.call_soon_threadsafe(callback, context=self.context)

try:
return future.result()
Expand Down
37 changes: 37 additions & 0 deletions Lib/test/test_repl.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,5 +291,42 @@ def f():
self.assertEqual(traceback_lines, expected_lines)


class TestAsyncioREPLContextVars(unittest.TestCase):
def test_toplevel_contextvars_sync(self):
user_input = dedent("""\
from contextvars import ContextVar
var = ContextVar("var", default="failed")
var.set("ok")
""")
p = spawn_repl("-m", "asyncio")
p.stdin.write(user_input)
user_input2 = dedent("""
print(f"toplevel contextvar test: {var.get()}")
""")
p.stdin.write(user_input2)
output = kill_python(p)
self.assertEqual(p.returncode, 0)
expected = "toplevel contextvar test: ok"
self.assertIn(expected, output, expected)

def test_toplevel_contextvars_async(self):
user_input = dedent("""\
from contextvars import ContextVar
var = ContextVar('var', default='failed')
""")
p = spawn_repl("-m", "asyncio")
p.stdin.write(user_input)
user_input2 = "async def set_var(): var.set('ok')\n"
p.stdin.write(user_input2)
user_input3 = "await set_var()\n"
p.stdin.write(user_input3)
user_input4 = "print(f'toplevel contextvar test: {var.get()}')\n"
p.stdin.write(user_input4)
output = kill_python(p)
self.assertEqual(p.returncode, 0)
expected = "toplevel contextvar test: ok"
self.assertIn(expected, output, expected)


if __name__ == "__main__":
unittest.main()
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
All asyncio REPL prompts run in the same context
Loading