@@ -378,6 +378,9 @@ def __init__(self, completekey='tab', stdin=None, stdout=None, skip=None,
378378        self .commands_bnum  =  None  # The breakpoint number for which we are 
379379                                  # defining a list 
380380
381+         self .async_shim_frame  =  None 
382+         self .async_awaitable  =  None 
383+ 
381384        self ._chained_exceptions  =  tuple ()
382385        self ._chained_exception_index  =  0 
383386
@@ -393,6 +396,57 @@ def set_trace(self, frame=None, *, commands=None):
393396
394397        super ().set_trace (frame )
395398
399+     async  def  set_trace_async (self , frame = None , * , commands = None ):
400+         if  self .async_awaitable  is  not None :
401+             # We are already in a set_trace_async call, do not mess with it 
402+             return 
403+ 
404+         if  frame  is  None :
405+             frame  =  sys ._getframe ().f_back 
406+ 
407+         # We need set_trace to set up the basics, however, this will call 
408+         # set_stepinstr() will we need to compensate for, because we don't 
409+         # want to trigger on calls 
410+         self .set_trace (frame , commands = commands )
411+         # Changing the stopframe will disable trace dispatch on calls 
412+         self .stopframe  =  frame 
413+         # We need to stop tracing because we don't have the privilege to avoid 
414+         # triggering tracing functions as normal, as we are not already in 
415+         # tracing functions 
416+         self .stop_trace ()
417+ 
418+         self .async_shim_frame  =  sys ._getframe ()
419+         self .async_awaitable  =  None 
420+ 
421+         while  True :
422+             self .async_awaitable  =  None 
423+             # Simulate a trace event 
424+             # This should bring up pdb and make pdb believe it's debugging the 
425+             # caller frame 
426+             self .trace_dispatch (frame , "opcode" , None )
427+             if  self .async_awaitable  is  not None :
428+                 try :
429+                     if  self .breaks :
430+                         with  self .set_enterframe (frame ):
431+                             # set_continue requires enterframe to work 
432+                             self .set_continue ()
433+                         self .start_trace ()
434+                     await  self .async_awaitable 
435+                 except  Exception :
436+                     self ._error_exc ()
437+             else :
438+                 break 
439+ 
440+         self .async_shim_frame  =  None 
441+ 
442+         # start the trace (the actual command is already set by set_* calls) 
443+         if  self .returnframe  is  None  and  self .stoplineno  ==  - 1  and  not  self .breaks :
444+             # This means we did a continue without any breakpoints, we should not 
445+             # start the trace 
446+             return 
447+ 
448+         self .start_trace ()
449+ 
396450    def  sigint_handler (self , signum , frame ):
397451        if  self .allow_kbdint :
398452            raise  KeyboardInterrupt 
@@ -775,6 +829,20 @@ def _exec_in_closure(self, source, globals, locals):
775829
776830        return  True 
777831
832+     def  _exec_await (self , source , globals , locals ):
833+         """ Run source code that contains await by playing with async shim frame""" 
834+         # Put the source in an async function 
835+         source_async  =  (
836+             "async def __pdb_await():\n "  + 
837+             textwrap .indent (source , "    " ) +  '\n '  + 
838+             "    __pdb_locals.update(locals())" 
839+         )
840+         ns  =  globals  |  locals 
841+         # We use __pdb_locals to do write back 
842+         ns ["__pdb_locals" ] =  locals 
843+         exec (source_async , ns )
844+         self .async_awaitable  =  ns ["__pdb_await" ]()
845+ 
778846    def  default (self , line ):
779847        if  line [:1 ] ==  '!' : line  =  line [1 :].strip ()
780848        locals  =  self .curframe .f_locals 
@@ -820,8 +888,20 @@ def default(self, line):
820888                sys .stdout  =  save_stdout 
821889                sys .stdin  =  save_stdin 
822890                sys .displayhook  =  save_displayhook 
823-         except :
824-             self ._error_exc ()
891+         except  Exception  as  e :
892+             # Maybe it's an await expression/statement 
893+             if  (
894+                 isinstance (e , SyntaxError )
895+                 and  e .msg  ==  "'await' outside function" 
896+                 and  self .async_shim_frame  is  not None 
897+             ):
898+                 try :
899+                     self ._exec_await (buffer , globals , locals )
900+                     return  True 
901+                 except :
902+                     self ._error_exc ()
903+             else :
904+                 self ._error_exc ()
825905
826906    def  _replace_convenience_variables (self , line ):
827907        """Replace the convenience variables in 'line' with their values. 
@@ -2491,6 +2571,21 @@ def set_trace(*, header=None, commands=None):
24912571        pdb .message (header )
24922572    pdb .set_trace (sys ._getframe ().f_back , commands = commands )
24932573
2574+ async  def  set_trace_async (* , header = None , commands = None ):
2575+     """Enter the debugger at the calling stack frame, but in async mode. 
2576+ 
2577+     This should be used as await pdb.set_trace_async(). Users can do await 
2578+     if they enter the debugger with this function. Otherwise it's the same 
2579+     as set_trace(). 
2580+     """ 
2581+     if  Pdb ._last_pdb_instance  is  not None :
2582+         pdb  =  Pdb ._last_pdb_instance 
2583+     else :
2584+         pdb  =  Pdb (mode = 'inline' , backend = 'monitoring' )
2585+     if  header  is  not None :
2586+         pdb .message (header )
2587+     await  pdb .set_trace_async (sys ._getframe ().f_back , commands = commands )
2588+ 
24942589# Post-Mortem interface 
24952590
24962591def  post_mortem (t = None ):
0 commit comments