Skip to content

Commit 1b77cdb

Browse files
Add set_trace_async for async support
1 parent 9634085 commit 1b77cdb

File tree

2 files changed

+161
-2
lines changed

2 files changed

+161
-2
lines changed

Lib/pdb.py

Lines changed: 97 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,9 @@ def __init__(self, completekey='tab', stdin=None, stdout=None, skip=None,
378378
self.commands_bnum = None # The breakpoint number for which we are
379379
# defining a list
380380

381+
self.async_shim_frame = None
382+
self.async_awaitable = None
383+
381384
self._chained_exceptions = tuple()
382385
self._chained_exception_index = 0
383386

@@ -393,6 +396,57 @@ def set_trace(self, frame=None, *, commands=None):
393396

394397
super().set_trace(frame)
395398

399+
async def set_trace_async(self, frame=None, *, commands=None):
400+
if self.async_awaitable is not None:
401+
# We are already in a set_trace_async call, do not mess with it
402+
return
403+
404+
if frame is None:
405+
frame = sys._getframe().f_back
406+
407+
# We need set_trace to set up the basics, however, this will call
408+
# set_stepinstr() will we need to compensate for, because we don't
409+
# want to trigger on calls
410+
self.set_trace(frame, commands=commands)
411+
# Changing the stopframe will disable trace dispatch on calls
412+
self.stopframe = frame
413+
# We need to stop tracing because we don't have the privilege to avoid
414+
# triggering tracing functions as normal, as we are not already in
415+
# tracing functions
416+
self.stop_trace()
417+
418+
self.async_shim_frame = sys._getframe()
419+
self.async_awaitable = None
420+
421+
while True:
422+
self.async_awaitable = None
423+
# Simulate a trace event
424+
# This should bring up pdb and make pdb believe it's debugging the
425+
# caller frame
426+
self.trace_dispatch(frame, "opcode", None)
427+
if self.async_awaitable is not None:
428+
try:
429+
if self.breaks:
430+
with self.set_enterframe(frame):
431+
# set_continue requires enterframe to work
432+
self.set_continue()
433+
self.start_trace()
434+
await self.async_awaitable
435+
except Exception:
436+
self._error_exc()
437+
else:
438+
break
439+
440+
self.async_shim_frame = None
441+
442+
# start the trace (the actual command is already set by set_* calls)
443+
if self.returnframe is None and self.stoplineno == -1 and not self.breaks:
444+
# This means we did a continue without any breakpoints, we should not
445+
# start the trace
446+
return
447+
448+
self.start_trace()
449+
396450
def sigint_handler(self, signum, frame):
397451
if self.allow_kbdint:
398452
raise KeyboardInterrupt
@@ -775,6 +829,20 @@ def _exec_in_closure(self, source, globals, locals):
775829

776830
return True
777831

832+
def _exec_await(self, source, globals, locals):
833+
""" Run source code that contains await by playing with async shim frame"""
834+
# Put the source in an async function
835+
source_async = (
836+
"async def __pdb_await():\n" +
837+
textwrap.indent(source, " ") + '\n' +
838+
" __pdb_locals.update(locals())"
839+
)
840+
ns = globals | locals
841+
# We use __pdb_locals to do write back
842+
ns["__pdb_locals"] = locals
843+
exec(source_async, ns)
844+
self.async_awaitable = ns["__pdb_await"]()
845+
778846
def default(self, line):
779847
if line[:1] == '!': line = line[1:].strip()
780848
locals = self.curframe.f_locals
@@ -820,8 +888,20 @@ def default(self, line):
820888
sys.stdout = save_stdout
821889
sys.stdin = save_stdin
822890
sys.displayhook = save_displayhook
823-
except:
824-
self._error_exc()
891+
except Exception as e:
892+
# Maybe it's an await expression/statement
893+
if (
894+
isinstance(e, SyntaxError)
895+
and e.msg == "'await' outside function"
896+
and self.async_shim_frame is not None
897+
):
898+
try:
899+
self._exec_await(buffer, globals, locals)
900+
return True
901+
except:
902+
self._error_exc()
903+
else:
904+
self._error_exc()
825905

826906
def _replace_convenience_variables(self, line):
827907
"""Replace the convenience variables in 'line' with their values.
@@ -2491,6 +2571,21 @@ def set_trace(*, header=None, commands=None):
24912571
pdb.message(header)
24922572
pdb.set_trace(sys._getframe().f_back, commands=commands)
24932573

2574+
async def set_trace_async(*, header=None, commands=None):
2575+
"""Enter the debugger at the calling stack frame, but in async mode.
2576+
2577+
This should be used as await pdb.set_trace_async(). Users can do await
2578+
if they enter the debugger with this function. Otherwise it's the same
2579+
as set_trace().
2580+
"""
2581+
if Pdb._last_pdb_instance is not None:
2582+
pdb = Pdb._last_pdb_instance
2583+
else:
2584+
pdb = Pdb(mode='inline', backend='monitoring')
2585+
if header is not None:
2586+
pdb.message(header)
2587+
await pdb.set_trace_async(sys._getframe().f_back, commands=commands)
2588+
24942589
# Post-Mortem interface
24952590

24962591
def post_mortem(t=None):

Lib/test/test_pdb.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2169,6 +2169,70 @@ def test_pdb_asynctask():
21692169
(Pdb) continue
21702170
"""
21712171

2172+
def test_pdb_await_support():
2173+
"""Testing await support in pdb
2174+
2175+
>>> import asyncio
2176+
2177+
>>> async def test():
2178+
... print("hello")
2179+
... await asyncio.sleep(0)
2180+
... print("world")
2181+
... return 42
2182+
2183+
>>> async def main():
2184+
... import pdb;
2185+
... task = asyncio.create_task(test())
2186+
... await pdb.Pdb(nosigint=True, readrc=False).set_trace_async()
2187+
... pass
2188+
2189+
>>> def test_function():
2190+
... asyncio.run(main(), loop_factory=asyncio.EventLoop)
2191+
2192+
>>> with PdbTestInput([ # doctest: +ELLIPSIS
2193+
... 'x = await task',
2194+
... 'p x',
2195+
... 'x = await test()',
2196+
... 'p x',
2197+
... 'new_task = asyncio.create_task(test())',
2198+
... 'await new_task',
2199+
... 'await non_exist()',
2200+
... 's',
2201+
... 'continue',
2202+
... ]):
2203+
... test_function()
2204+
> <doctest test.test_pdb.test_pdb_await_support[2]>(4)main()
2205+
-> await pdb.Pdb(nosigint=True, readrc=False).set_trace_async()
2206+
(Pdb) x = await task
2207+
hello
2208+
world
2209+
> <doctest test.test_pdb.test_pdb_await_support[2]>(4)main()
2210+
-> await pdb.Pdb(nosigint=True, readrc=False).set_trace_async()
2211+
(Pdb) p x
2212+
42
2213+
(Pdb) x = await test()
2214+
hello
2215+
world
2216+
> <doctest test.test_pdb.test_pdb_await_support[2]>(4)main()
2217+
-> await pdb.Pdb(nosigint=True, readrc=False).set_trace_async()
2218+
(Pdb) p x
2219+
42
2220+
(Pdb) new_task = asyncio.create_task(test())
2221+
(Pdb) await new_task
2222+
hello
2223+
world
2224+
> <doctest test.test_pdb.test_pdb_await_support[2]>(4)main()
2225+
-> await pdb.Pdb(nosigint=True, readrc=False).set_trace_async()
2226+
(Pdb) await non_exist()
2227+
*** NameError: name 'non_exist' is not defined
2228+
> <doctest test.test_pdb.test_pdb_await_support[2]>(4)main()
2229+
-> await pdb.Pdb(nosigint=True, readrc=False).set_trace_async()
2230+
(Pdb) s
2231+
> <doctest test.test_pdb.test_pdb_await_support[2]>(5)main()
2232+
-> pass
2233+
(Pdb) continue
2234+
"""
2235+
21722236
def test_pdb_next_command_for_coroutine():
21732237
"""Testing skip unwinding stack on yield for coroutines for "next" command
21742238

0 commit comments

Comments
 (0)