Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions newsfragments/3007.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Make ctrl+c work in more situations in the Trio REPL (``python -m trio``).
75 changes: 72 additions & 3 deletions src/trio/_repl.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
import contextlib
import inspect
import sys
import types
import warnings
from code import InteractiveConsole
from types import CodeType, FrameType, FunctionType
from typing import Callable

import outcome

Expand All @@ -15,14 +16,33 @@
from trio._util import final


class SuppressDecorator(contextlib.ContextDecorator, contextlib.suppress):
pass


@SuppressDecorator(KeyboardInterrupt)
@trio.lowlevel.disable_ki_protection
def terminal_newline() -> None: # TODO: test this line
import fcntl
import termios

# Fake up a newline char as if user had typed it at the terminal
try:
fcntl.ioctl(sys.stdin, termios.TIOCSTI, b"\n") # type: ignore[attr-defined, unused-ignore]
except OSError as e:
print(f"\nPress enter! Newline injection failed: {e}", end="", flush=True)


@final
class TrioInteractiveConsole(InteractiveConsole):
def __init__(self, repl_locals: dict[str, object] | None = None) -> None:
super().__init__(locals=repl_locals)
self.token: trio.lowlevel.TrioToken | None = None
self.compile.compiler.flags |= ast.PyCF_ALLOW_TOP_LEVEL_AWAIT
self.interrupted = False

def runcode(self, code: types.CodeType) -> None:
func = types.FunctionType(code, self.locals)
def runcode(self, code: CodeType) -> None:
func = FunctionType(code, self.locals)
if inspect.iscoroutinefunction(func):
result = trio.from_thread.run(outcome.acapture, func)
else:
Expand All @@ -48,6 +68,55 @@ def runcode(self, code: types.CodeType) -> None:
# We always use sys.excepthook, unlike other implementations.
# This means that overriding self.write also does nothing to tbs.
sys.excepthook(sys.last_type, sys.last_value, sys.last_traceback)
# clear any residual KI
trio.from_thread.run(trio.lowlevel.checkpoint_if_cancelled)
# trio.from_thread.check_cancelled() has too long of a memory

if sys.platform == "win32": # TODO: test this line

def raw_input(self, prompt: str = "") -> str:
try:
return input(prompt)
except EOFError:
# check if trio has a pending KI
trio.from_thread.run(trio.lowlevel.checkpoint_if_cancelled)
raise

else:

def raw_input(self, prompt: str = "") -> str:
from signal import SIGINT, signal

assert not self.interrupted

def install_handler() -> (
Callable[[int, FrameType | None], None] | int | None
):
def handler(
sig: int, frame: FrameType | None
) -> None: # TODO: test this line
self.interrupted = True
token.run_sync_soon(terminal_newline, idempotent=True)

token = trio.lowlevel.current_trio_token()

return signal(SIGINT, handler)

prev_handler = trio.from_thread.run_sync(install_handler)
try:
return input(prompt)
finally:
trio.from_thread.run_sync(signal, SIGINT, prev_handler)
if self.interrupted: # TODO: test this line
raise KeyboardInterrupt

def write(self, output: str) -> None:
if self.interrupted: # TODO: test this line
assert output == "\nKeyboardInterrupt\n"
sys.stderr.write(output[1:])
self.interrupted = False
else:
sys.stderr.write(output)


async def run_repl(console: TrioInteractiveConsole) -> None:
Expand Down
194 changes: 194 additions & 0 deletions src/trio/_tests/test_repl.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from __future__ import annotations

import os
import pathlib
import signal
import subprocess
import sys
from functools import partial
from typing import Protocol

import pytest
Expand Down Expand Up @@ -239,3 +243,193 @@ def test_main_entrypoint() -> None:
"""
repl = subprocess.run([sys.executable, "-m", "trio"], input=b"exit()")
assert repl.returncode == 0


def should_try_newline_injection() -> bool:
if sys.platform != "linux":
return False

sysctl = pathlib.Path("/proc/sys/dev/tty/legacy_tiocsti")
if not sysctl.exists(): # pragma: no cover
return True

else:
return sysctl.read_text() == "1"


@pytest.mark.skipif(
not should_try_newline_injection(),
reason="the ioctl we use is disabled in CI",
)
def test_ki_newline_injection() -> None: # TODO: test this line
# TODO: we want to remove this functionality, eg by using vendored
# pyrepls.
assert sys.platform != "win32"

import pty

# NOTE: this cannot be subprocess.Popen because pty.fork
# does some magic to set the controlling terminal.
# (which I don't know how to replicate... so I copied this
# structure from pty.spawn...)
pid, pty_fd = pty.fork() # type: ignore[attr-defined,unused-ignore]
if pid == 0:
os.execlp(sys.executable, *[sys.executable, "-u", "-m", "trio"])

# setup:
buffer = b""
while not buffer.endswith(b"import trio\r\n>>> "):
buffer += os.read(pty_fd, 4096)

# sanity check:
print(buffer.decode())
buffer = b""
os.write(pty_fd, b'print("hello!")\n')
while not buffer.endswith(b">>> "):
buffer += os.read(pty_fd, 4096)

assert buffer.count(b"hello!") == 2

# press ctrl+c
print(buffer.decode())
buffer = b""
os.kill(pid, signal.SIGINT)
while not buffer.endswith(b">>> "):
buffer += os.read(pty_fd, 4096)

assert b"KeyboardInterrupt" in buffer

# press ctrl+c later
print(buffer.decode())
buffer = b""
os.write(pty_fd, b'print("hello!")')
os.kill(pid, signal.SIGINT)
while not buffer.endswith(b">>> "):
buffer += os.read(pty_fd, 4096)

assert b"KeyboardInterrupt" in buffer
print(buffer.decode())
os.close(pty_fd)
os.waitpid(pid, 0)[1]


async def test_ki_in_repl() -> None:
async with trio.open_nursery() as nursery:
proc = await nursery.start(
partial(
trio.run_process,
[sys.executable, "-u", "-m", "trio"],
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
stdin=subprocess.PIPE,
creationflags=subprocess.CREATE_NEW_PROCESS_GROUP if sys.platform == "win32" else 0, # type: ignore[attr-defined,unused-ignore]
)
)

async with proc.stdout:
# setup
buffer = b""
async for part in proc.stdout: # pragma: no branch
buffer += part
# TODO: consider making run_process stdout have some universal newlines thing
if buffer.replace(b"\r\n", b"\n").endswith(b"import trio\n>>> "):
break

# ensure things work
print(buffer.decode())
buffer = b""
await proc.stdin.send_all(b'print("hello!")\n')
async for part in proc.stdout: # pragma: no branch
buffer += part
if buffer.endswith(b">>> "):
break

assert b"hello!" in buffer
print(buffer.decode())

# this seems to be necessary on Windows for reasons
# (the parents of process groups ignore ctrl+c by default...)
if sys.platform == "win32":
buffer = b""
await proc.stdin.send_all(
b"import ctypes; ctypes.windll.kernel32.SetConsoleCtrlHandler(None, False)\n"
)
async for part in proc.stdout: # pragma: no branch
buffer += part
if buffer.endswith(b">>> "):
break

print(buffer.decode())

# try to decrease flakiness...
buffer = b""
await proc.stdin.send_all(
b"import coverage; trio.lowlevel.enable_ki_protection(coverage.pytracer.PyTracer._trace)\n"
)
async for part in proc.stdout: # pragma: no branch
buffer += part
if buffer.endswith(b">>> "):
break

print(buffer.decode())

# ensure that ctrl+c on a prompt works
# NOTE: for some reason, signal.SIGINT doesn't work for this test.
# Using CTRL_C_EVENT is also why we need subprocess.CREATE_NEW_PROCESS_GROUP
signal_sent = signal.CTRL_C_EVENT if sys.platform == "win32" else signal.SIGINT # type: ignore[attr-defined,unused-ignore]
os.kill(proc.pid, signal_sent)
if sys.platform == "win32":
# we rely on EOFError which... doesn't happen with pipes.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I never dug in to why EOFError pops out, maybe that knowledge could help us test here

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I haven't dug into this yet unfortunately.

# I'm not sure how to fix it...
await proc.stdin.send_all(b"\n")
else:
# we test injection separately
await proc.stdin.send_all(b"\n")

buffer = b""
async for part in proc.stdout: # pragma: no branch
buffer += part
if buffer.endswith(b">>> "):
break

assert b"KeyboardInterrupt" in buffer

# ensure ctrl+c while a command runs works
print(buffer.decode())
await proc.stdin.send_all(b'print("READY"); await trio.sleep_forever()\n')
killed = False
buffer = b""
async for part in proc.stdout: # pragma: no branch
buffer += part
if buffer.replace(b"\r\n", b"\n").endswith(b"READY\n") and not killed:
os.kill(proc.pid, signal_sent)
killed = True
if buffer.endswith(b">>> "):
break

assert b"trio" in buffer
assert b"KeyboardInterrupt" in buffer

# make sure it works for sync commands too
# (though this would be hard to break)
print(buffer.decode())
await proc.stdin.send_all(
b'import time; print("READY"); time.sleep(99999)\n'
)
killed = False
buffer = b""
async for part in proc.stdout: # pragma: no branch
buffer += part
if buffer.replace(b"\r\n", b"\n").endswith(b"READY\n") and not killed:
os.kill(proc.pid, signal_sent)
killed = True
if buffer.endswith(b">>> "):
break

assert b"Traceback" in buffer
assert b"KeyboardInterrupt" in buffer

print(buffer.decode())

# kill the process
nursery.cancel_scope.cancel()
Loading