Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions newsfragments/3007.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Make ctrl+c work in more situations in the Trio REPL (``python -m trio``).
25 changes: 18 additions & 7 deletions src/trio/_repl.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import sys
import warnings
from code import InteractiveConsole
from types import CodeType, FunctionType, FrameType
from types import CodeType, FrameType, FunctionType
from typing import Callable

import outcome
Expand All @@ -27,7 +27,9 @@ def terminal_newline() -> None:
import termios

# Fake up a newline char as if user had typed it at the terminal
fcntl.ioctl(sys.stdin, termios.TIOCSTI, b"\n")
# on a best-effort basis
with contextlib.suppress(OSError):
fcntl.ioctl(sys.stdin, termios.TIOCSTI, b"\n") # type: ignore[attr-defined, unused-ignore]


@final
Expand All @@ -36,9 +38,11 @@ def __init__(self, repl_locals: dict[str, object] | None = None) -> None:
super().__init__(locals=repl_locals)
self.token: trio.lowlevel.TrioToken | None = None
self.compile.compiler.flags |= ast.PyCF_ALLOW_TOP_LEVEL_AWAIT
self.interrupted = False

def runcode(self, code: CodeType) -> None:
func = FunctionType(code, self.locals)
# https://github.com/python/typeshed/issues/13768
func = FunctionType(code, self.locals) # type: ignore[arg-type]
if inspect.iscoroutinefunction(func):
result = trio.from_thread.run(outcome.acapture, func)
else:
Expand Down Expand Up @@ -83,14 +87,13 @@ def raw_input(self, prompt: str = "") -> str:
def raw_input(self, prompt: str = "") -> str:
from signal import SIGINT, signal

interrupted = False
assert not self.interrupted

def install_handler() -> (
Callable[[int, FrameType | None], None] | int | None
):
def handler(sig: int, frame: FrameType | None) -> None:
nonlocal interrupted
interrupted = True
self.interrupted = True
token.run_sync_soon(terminal_newline, idempotent=True)

token = trio.lowlevel.current_trio_token()
Expand All @@ -102,9 +105,17 @@ def handler(sig: int, frame: FrameType | None) -> None:
return input(prompt)
finally:
trio.from_thread.run_sync(signal, SIGINT, prev_handler)
if interrupted:
if self.interrupted:
raise KeyboardInterrupt

def write(self, output: str) -> None:
if self.interrupted:
assert output == "\nKeyboardInterrupt\n"
sys.stderr.write(output[1:])
self.interrupted = False
else:
sys.stderr.write(output)


async def run_repl(console: TrioInteractiveConsole) -> None:
banner = (
Expand Down
179 changes: 179 additions & 0 deletions src/trio/_tests/test_repl.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from __future__ import annotations

import os
import signal
import subprocess
import sys
from functools import partial
from typing import Protocol

import pytest
Expand Down Expand Up @@ -239,3 +242,179 @@ def test_main_entrypoint() -> None:
"""
repl = subprocess.run([sys.executable, "-m", "trio"], input=b"exit()")
assert repl.returncode == 0


# TODO: skip this based on sysctls? Or Linux version?
@pytest.mark.skipif(True, reason="the ioctl we use is disabled in CI")
def test_ki_newline_injection() -> None:
# TODO: we want to remove this functionality, eg by using vendored
# pyrepls.
assert sys.platform != "win32"

import pty

# NOTE: this cannot be subprocess.Popen because pty.fork
# does some magic to set the controlling terminal.
# (which I don't know how to replicate... so I copied this
# structure from pty.spawn...)
pid, pty_fd = pty.fork() # type: ignore[attr-defined,unused-ignore]
if pid == 0:
os.execlp(sys.executable, *[sys.executable, "-u", "-m", "trio"])

# setup:
buffer = b""
while not buffer.endswith(b"import trio\r\n>>> "):
buffer += os.read(pty_fd, 4096)

# sanity check:
print(buffer.decode())
buffer = b""
os.write(pty_fd, b'print("hello!")\n')
while not buffer.endswith(b">>> "):
buffer += os.read(pty_fd, 4096)

assert buffer.count(b"hello!") == 2

# press ctrl+c
print(buffer.decode())
buffer = b""
os.kill(pid, signal.SIGINT)
while not buffer.endswith(b">>> "):
buffer += os.read(pty_fd, 4096)

assert b"KeyboardInterrupt" in buffer

# press ctrl+c later
print(buffer.decode())
buffer = b""
os.write(pty_fd, b'print("hello!")')
os.kill(pid, signal.SIGINT)
while not buffer.endswith(b">>> "):
buffer += os.read(pty_fd, 4096)

assert b"KeyboardInterrupt" in buffer
print(buffer.decode())
os.close(pty_fd)
os.waitpid(pid, 0)[1]


async def test_ki_in_repl() -> None:
async with trio.open_nursery() as nursery:
proc = await nursery.start(
partial(
trio.run_process,
[sys.executable, "-u", "-m", "trio"],
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
stdin=subprocess.PIPE,
creationflags=subprocess.CREATE_NEW_PROCESS_GROUP if sys.platform == "win32" else 0, # type: ignore[attr-defined,unused-ignore]
)
)

async with proc.stdout:
# setup
buffer = b""
async for part in proc.stdout:
buffer += part
# TODO: consider making run_process stdout have some universal newlines thing
if buffer.replace(b"\r\n", b"\n").endswith(b"import trio\n>>> "):
break

# ensure things work
print(buffer.decode())
buffer = b""
await proc.stdin.send_all(b'print("hello!")\n')
async for part in proc.stdout:
buffer += part
if buffer.endswith(b">>> "):
break

assert b"hello!" in buffer
print(buffer.decode())

# this seems to be necessary on Windows for reasons
# (the parents of process groups ignore ctrl+c by default...)
if sys.platform == "win32":
buffer = b""
await proc.stdin.send_all(
b"import ctypes; ctypes.windll.kernel32.SetConsoleCtrlHandler(None, False)\n"
)
async for part in proc.stdout:
buffer += part
if buffer.endswith(b">>> "):
break

print(buffer.decode())

# try to decrease flakiness...
buffer = b""
await proc.stdin.send_all(
b"import coverage; trio.lowlevel.enable_ki_protection(coverage.pytracer.PyTracer._trace)\n"
)
async for part in proc.stdout:
buffer += part
if buffer.endswith(b">>> "):
break

print(buffer.decode())

# ensure that ctrl+c on a prompt works
# NOTE: for some reason, signal.SIGINT doesn't work for this test.
# Using CTRL_C_EVENT is also why we need subprocess.CREATE_NEW_PROCESS_GROUP
signal_sent = signal.CTRL_C_EVENT if sys.platform == "win32" else signal.SIGINT # type: ignore[attr-defined,unused-ignore]
os.kill(proc.pid, signal_sent)
if sys.platform == "win32":
# we rely on EOFError which... doesn't happen with pipes.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I never dug in to why EOFError pops out, maybe that knowledge could help us test here

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I haven't dug into this yet unfortunately.

# I'm not sure how to fix it...
await proc.stdin.send_all(b"\n")
else:
# we test injection separately
await proc.stdin.send_all(b"\n")

buffer = b""
async for part in proc.stdout:
buffer += part
if buffer.endswith(b">>> "):
break

assert b"KeyboardInterrupt" in buffer

# ensure ctrl+c while a command runs works
print(buffer.decode())
await proc.stdin.send_all(b'print("READY"); await trio.sleep_forever()\n')
killed = False
buffer = b""
async for part in proc.stdout:
buffer += part
if buffer.replace(b"\r\n", b"\n").endswith(b"READY\n") and not killed:
os.kill(proc.pid, signal_sent)
killed = True
if buffer.endswith(b">>> "):
break

assert b"trio" in buffer
assert b"KeyboardInterrupt" in buffer

# make sure it works for sync commands too
# (though this would be hard to break)
print(buffer.decode())
await proc.stdin.send_all(
b'import time; print("READY"); time.sleep(99999)\n'
)
killed = False
buffer = b""
async for part in proc.stdout:
buffer += part
if buffer.replace(b"\r\n", b"\n").endswith(b"READY\n") and not killed:
os.kill(proc.pid, signal_sent)
killed = True
if buffer.endswith(b">>> "):
break

assert b"Traceback" in buffer
assert b"KeyboardInterrupt" in buffer

print(buffer.decode())

# kill the process
nursery.cancel_scope.cancel()