Skip to content

Commit 9fabf0c

Browse files
tweaks to new REPL (#3002)
1 parent 3350c11 commit 9fabf0c

File tree

2 files changed

+59
-17
lines changed

2 files changed

+59
-17
lines changed

src/trio/_repl.py

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,10 @@
1212

1313
import trio
1414
import trio.lowlevel
15+
from trio._util import final
1516

1617

18+
@final
1719
class TrioInteractiveConsole(InteractiveConsole):
1820
# code.InteractiveInterpreter defines locals as Mapping[str, Any]
1921
# but when we pass this to FunctionType it expects a dict. So
@@ -25,25 +27,32 @@ def __init__(self, repl_locals: dict[str, object] | None = None):
2527
self.compile.compiler.flags |= ast.PyCF_ALLOW_TOP_LEVEL_AWAIT
2628

2729
def runcode(self, code: types.CodeType) -> None:
28-
async def _runcode_in_trio() -> outcome.Outcome[object]:
29-
func = types.FunctionType(code, self.locals)
30-
if inspect.iscoroutinefunction(func):
31-
return await outcome.acapture(func)
30+
func = types.FunctionType(code, self.locals)
31+
if inspect.iscoroutinefunction(func):
32+
result = trio.from_thread.run(outcome.acapture, func)
33+
else:
34+
result = trio.from_thread.run_sync(outcome.capture, func)
35+
if isinstance(result, outcome.Error):
36+
# If it is SystemExit, quit the repl. Otherwise, print the traceback.
37+
# If there is a SystemExit inside a BaseExceptionGroup, it probably isn't
38+
# the user trying to quit the repl, but rather an error in the code. So, we
39+
# don't try to inspect groups for SystemExit. Instead, we just print and
40+
# return to the REPL.
41+
if isinstance(result.error, SystemExit):
42+
raise result.error
3243
else:
33-
return outcome.capture(func)
44+
# Inline our own version of self.showtraceback that can use
45+
# outcome.Error.error directly to print clean tracebacks.
46+
# This also means overriding self.showtraceback does nothing.
47+
sys.last_type, sys.last_value = type(result.error), result.error
48+
sys.last_traceback = result.error.__traceback__
49+
# see https://docs.python.org/3/library/sys.html#sys.last_exc
50+
if sys.version_info >= (3, 12):
51+
sys.last_exc = result.error
3452

35-
try:
36-
trio.from_thread.run(_runcode_in_trio).unwrap()
37-
except SystemExit:
38-
# If it is SystemExit quit the repl. Otherwise, print the
39-
# traceback.
40-
# There could be a SystemExit inside a BaseExceptionGroup. If
41-
# that happens, it probably isn't the user trying to quit the
42-
# repl, but an error in the code. So we print the exception
43-
# and stay in the repl.
44-
raise
45-
except BaseException:
46-
self.showtraceback()
53+
# We always use sys.excepthook, unlike other implementations.
54+
# This means that overriding self.write also does nothing to tbs.
55+
sys.excepthook(sys.last_type, sys.last_value, sys.last_traceback)
4756

4857

4958
async def run_repl(console: TrioInteractiveConsole) -> None:

src/trio/_tests/test_repl.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,35 @@ async def test_system_exits_quit_interpreter(monkeypatch: pytest.MonkeyPatch) ->
9696
await trio._repl.run_repl(console)
9797

9898

99+
async def test_KI_interrupts(
100+
capsys: pytest.CaptureFixture[str],
101+
monkeypatch: pytest.MonkeyPatch,
102+
) -> None:
103+
console = trio._repl.TrioInteractiveConsole(repl_locals=build_locals())
104+
raw_input = build_raw_input(
105+
[
106+
"from trio._util import signal_raise",
107+
"import signal, trio, trio.lowlevel",
108+
"async def f():",
109+
" trio.lowlevel.spawn_system_task("
110+
" trio.to_thread.run_sync,"
111+
" signal_raise,signal.SIGINT,"
112+
" )", # just awaiting this kills the test runner?!
113+
" await trio.sleep_forever()",
114+
" print('should not see this')",
115+
"",
116+
"await f()",
117+
"print('AFTER KeyboardInterrupt')",
118+
]
119+
)
120+
monkeypatch.setattr(console, "raw_input", raw_input)
121+
await trio._repl.run_repl(console)
122+
out, err = capsys.readouterr()
123+
assert "KeyboardInterrupt" in err
124+
assert "should" not in out
125+
assert "AFTER KeyboardInterrupt" in out
126+
127+
99128
async def test_system_exits_in_exc_group(
100129
capsys: pytest.CaptureFixture[str],
101130
monkeypatch: pytest.MonkeyPatch,
@@ -158,6 +187,8 @@ async def test_base_exception_captured(
158187
monkeypatch.setattr(console, "raw_input", raw_input)
159188
await trio._repl.run_repl(console)
160189
out, err = capsys.readouterr()
190+
assert "_threads.py" not in err
191+
assert "_repl.py" not in err
161192
assert "AFTER BaseException" in out
162193

163194

@@ -198,6 +229,8 @@ async def test_base_exception_capture_from_coroutine(
198229
monkeypatch.setattr(console, "raw_input", raw_input)
199230
await trio._repl.run_repl(console)
200231
out, err = capsys.readouterr()
232+
assert "_threads.py" not in err
233+
assert "_repl.py" not in err
201234
assert "AFTER BaseException" in out
202235

203236

0 commit comments

Comments
 (0)