@@ -3,7 +3,7 @@ use std::ffi::CString;
33use std:: fmt:: Display ;
44
55use proc_macro2:: { Span , TokenStream } ;
6- use quote:: { format_ident , quote, quote_spanned, ToTokens } ;
6+ use quote:: { quote, quote_spanned, ToTokens } ;
77use syn:: LitCStr ;
88use syn:: { ext:: IdentExt , spanned:: Spanned , Ident , Result } ;
99
@@ -666,99 +666,116 @@ impl<'a> FnSpec<'a> {
666666 }
667667 }
668668
669- let rust_call = |args : Vec < TokenStream > , holders : & mut Holders | {
670- let mut self_arg = || self . tp . self_arg ( cls, ExtractErrorMode :: Raise , holders, ctx) ;
669+ let rust_call = |args : Vec < TokenStream > , mut holders : Holders | {
670+ let self_arg = self
671+ . tp
672+ . self_arg ( cls, ExtractErrorMode :: Raise , & mut holders, ctx) ;
673+ let init_holders = holders. init_holders ( ctx) ;
671674
672- let call = if self . asyncness . is_some ( ) {
673- let throw_callback = if cancel_handle. is_some ( ) {
674- quote ! { Some ( __throw_callback) }
675+ // We must assign the output_span to the return value of the call,
676+ // but *not* of the call itself otherwise the spans get really weird
677+ let ret_ident = Ident :: new ( "ret" , * output_span) ;
678+
679+ if self . asyncness . is_some ( ) {
680+ // For async functions, we need to build up a coroutine object to return from the initial function call.
681+ //
682+ // Extraction of the call signature (positional & keyword arguments) happens as part of the initial function
683+ // call. The Python objects are then moved into the Rust future that will be executed when the coroutine is
684+ // awaited.
685+ //
686+ // The argument extraction from Python objects to Rust values then happens inside the future, this allows
687+ // things like extraction to `&MyClass` which needs a holder (for the class guard) to work properly inside
688+ // async code.
689+ //
690+ // It *might* be possible in the future to do the extraction before the coroutine is created, but that would require
691+ // changing argument extraction code to first create holders and then read the values from them later.
692+ let ( throw_callback, init_throw_callback) = if cancel_handle. is_some ( ) {
693+ (
694+ quote ! { Some ( __throw_callback) } ,
695+ Some (
696+ quote ! { let __cancel_handle = #pyo3_path:: coroutine:: CancelHandle :: new( ) ;
697+ let __throw_callback = __cancel_handle. throw_callback( ) ; } ,
698+ ) ,
699+ )
675700 } else {
676- quote ! { None }
701+ ( quote ! { None } , None )
677702 } ;
678703 let python_name = & self . python_name ;
679704 let qualname_prefix = match cls {
680705 Some ( cls) => quote ! ( Some ( <#cls as #pyo3_path:: PyClass >:: NAME ) ) ,
681706 None => quote ! ( None ) ,
682707 } ;
683- let future = match self . tp {
684- // If extracting `self`, we move the `_slf` pointer into the async block. This reduces the lifetime for which the Rust state is considered "borrowed"
685- // to just when the async block is executing.
686- //
687- // TODO: we should do this with all arguments, not just `self`, e.g. https://github.com/PyO3/pyo3/issues/5681
688- FnType :: Fn ( SelfType :: Receiver { mutable, .. } ) => {
689- let arg_names = ( 0 ..args. len ( ) )
690- . map ( |i| format_ident ! ( "arg_{}" , i) )
691- . collect :: < Vec < _ > > ( ) ;
692- let method = syn:: Ident :: new (
693- if mutable {
694- "extract_pyclass_ref_mut"
695- } else {
696- "extract_pyclass_ref"
697- } ,
698- Span :: call_site ( ) ,
699- ) ;
700- quote ! { {
701- let _slf = unsafe { #pyo3_path:: impl_:: extract_argument:: cast_function_argument( py, _slf) } . to_owned( ) . unbind( ) ;
702- #( let #arg_names = #args; ) *
703- async move {
704- // SAFETY: attached when future is polled (see `Coroutine::poll`)
705- let assume_attached = unsafe { #pyo3_path:: impl_:: coroutine:: AssumeAttachedInCoroutine :: new( ) } ;
706- let py = assume_attached. py( ) ;
707- let mut holder = None ;
708- let future = function(
709- #pyo3_path:: impl_:: extract_argument:: #method( _slf. bind_borrowed( py) , & mut holder) ?,
710- #( #arg_names) , *
711- ) ;
712- drop( py) ;
713- let result = future. await ;
714- let result: #pyo3_path:: PyResult <_> = #pyo3_path:: impl_:: wrap:: converter( & result) . wrap( result) . map_err( :: std:: convert:: Into :: into) ;
715- result
716- }
717- } }
718- }
719- _ => {
720- let args = self_arg ( ) . into_iter ( ) . chain ( args) ;
721- quote ! { function( #( #args) , * ) }
722- }
708+ // copy self arg into async block
709+ // slf_py will create the owned value to store in the future
710+ // slf_ptr recreates the raw pointer temporarily when building the future
711+ let ( slf_py, slf_ptr) = if self_arg. is_some ( ) {
712+ (
713+ Some (
714+ quote ! { let _slf = #pyo3_path:: Borrowed :: from_ptr( py, _slf) . to_owned( ) . unbind( ) ; } ,
715+ ) ,
716+ Some ( quote ! { let _slf = _slf. as_ptr( ) ; } ) ,
717+ )
718+ } else {
719+ ( None , None )
723720 } ;
724- let mut call = quote ! { {
725- let future = #future;
726- #pyo3_path:: impl_:: coroutine:: new_coroutine(
727- #pyo3_path:: intern!( py, stringify!( #python_name) ) ,
728- #qualname_prefix,
729- #throw_callback,
730- async move {
731- // SAFETY: attached when future is polled (see `Coroutine::poll`)
732- let assume_attached = unsafe { #pyo3_path:: impl_:: coroutine:: AssumeAttachedInCoroutine :: new( ) } ;
733- let output = future. await ;
734- let res = #pyo3_path:: impl_:: wrap:: converter( & output) . wrap( output) . map_err( :: std:: convert:: Into :: into) ;
735- #pyo3_path:: impl_:: wrap:: converter( & res) . map_into_pyobject( assume_attached. py( ) , res)
736- } ,
721+ // copy extracted arguments into async block
722+ // output_py will create the owned arguments to store in the future
723+ // output_args recreates the borrowed objects temporarily when building the future
724+ let ( output_py, output_args) = if !matches ! ( convention, CallingConvention :: Noargs ) {
725+ (
726+ Some ( quote ! {
727+ let output = output. map( |o| o. map( Py :: from) ) ;
728+ } ) ,
729+ Some ( quote ! {
730+ let output = output. each_ref( ) . map( |o| o. as_ref( ) . map( |obj| obj. bind_borrowed( assume_attached. py( ) ) ) ) ;
731+ } ) ,
737732 )
738- } } ;
739- if cancel_handle. is_some ( ) {
740- call = quote ! { {
741- let __cancel_handle = #pyo3_path:: coroutine:: CancelHandle :: new( ) ;
742- let __throw_callback = __cancel_handle. throw_callback( ) ;
743- #call
744- } } ;
733+ } else {
734+ ( None , None )
735+ } ;
736+ let args = self_arg. into_iter ( ) . chain ( args) ;
737+ let ok_wrap = quotes:: ok_wrap ( ret_ident. to_token_stream ( ) , ctx) ;
738+ quote ! {
739+ {
740+ let coroutine = {
741+ #slf_py
742+ #output_py
743+ #init_throw_callback
744+ #pyo3_path:: impl_:: coroutine:: new_coroutine(
745+ #pyo3_path:: intern!( py, stringify!( #python_name) ) ,
746+ #qualname_prefix,
747+ #throw_callback,
748+ async move {
749+ // SAFETY: attached when future is polled (see `Coroutine::poll`)
750+ let assume_attached = unsafe { #pyo3_path:: impl_:: coroutine:: AssumeAttachedInCoroutine :: new( ) } ;
751+ #init_holders
752+ let future = {
753+ let py = assume_attached. py( ) ;
754+ #slf_ptr
755+ #output_args
756+ function( #( #args) , * )
757+ } ;
758+ let #ret_ident = future. await ;
759+ let #ret_ident = #ok_wrap;
760+ #pyo3_path:: impl_:: wrap:: converter( & #ret_ident) . map_into_pyobject( assume_attached. py( ) , #ret_ident)
761+ } ,
762+ )
763+ } ;
764+ #pyo3_path:: Py :: new( py, coroutine) . map( #pyo3_path:: Py :: into_ptr)
765+ }
745766 }
746- call
747767 } else {
748- let args = self_arg ( ) . into_iter ( ) . chain ( args) ;
749- quote ! { function( #( #args) , * ) }
750- } ;
751-
752- // We must assign the output_span to the return value of the call,
753- // but *not* of the call itself otherwise the spans get really weird
754- let ret_ident = Ident :: new ( "ret" , * output_span) ;
755- let ret_expr = quote ! { let #ret_ident = #call; } ;
756- let return_conversion =
757- quotes:: map_result_into_ptr ( quotes:: ok_wrap ( ret_ident. to_token_stream ( ) , ctx) , ctx) ;
758- quote ! {
759- {
760- #ret_expr
761- #return_conversion
768+ let args = self_arg. into_iter ( ) . chain ( args) ;
769+ let return_conversion = quotes:: map_result_into_ptr (
770+ quotes:: ok_wrap ( ret_ident. to_token_stream ( ) , ctx) ,
771+ ctx,
772+ ) ;
773+ quote ! {
774+ {
775+ #init_holders
776+ let #ret_ident = function( #( #args) , * ) ;
777+ #return_conversion
778+ }
762779 }
763780 }
764781 } ;
@@ -771,10 +788,10 @@ impl<'a> FnSpec<'a> {
771788 } ;
772789
773790 let warnings = self . warnings . build_py_warning ( ctx) ;
791+ let mut holders = Holders :: new ( ) ;
774792
775793 Ok ( match convention {
776794 CallingConvention :: Noargs => {
777- let mut holders = Holders :: new ( ) ;
778795 let args = self
779796 . signature
780797 . arguments
@@ -785,26 +802,22 @@ impl<'a> FnSpec<'a> {
785802 _ => unreachable ! ( "`CallingConvention::Noargs` should not contain any arguments (reaching Python) except for `self`, which is handled below." ) ,
786803 } )
787804 . collect ( ) ;
788- let call = rust_call ( args, & mut holders) ;
789- let init_holders = holders. init_holders ( ctx) ;
805+ let call = rust_call ( args, holders) ;
790806 quote ! {
791807 unsafe fn #ident<' py>(
792808 py: #pyo3_path:: Python <' py>,
793809 _slf: * mut #pyo3_path:: ffi:: PyObject ,
794810 ) -> #pyo3_path:: PyResult <* mut #pyo3_path:: ffi:: PyObject > {
795811 let function = #rust_name; // Shadow the function name to avoid #3017
796- #init_holders
797812 #warnings
798813 let result = #call;
799814 result
800815 }
801816 }
802817 }
803818 CallingConvention :: Fastcall => {
804- let mut holders = Holders :: new ( ) ;
805819 let ( arg_convert, args) = impl_arg_params ( self , cls, true , & mut holders, ctx) ;
806- let call = rust_call ( args, & mut holders) ;
807- let init_holders = holders. init_holders ( ctx) ;
820+ let call = rust_call ( args, holders) ;
808821
809822 quote ! {
810823 unsafe fn #ident<' py>(
@@ -816,18 +829,15 @@ impl<'a> FnSpec<'a> {
816829 ) -> #pyo3_path:: PyResult <* mut #pyo3_path:: ffi:: PyObject > {
817830 let function = #rust_name; // Shadow the function name to avoid #3017
818831 #arg_convert
819- #init_holders
820832 #warnings
821833 let result = #call;
822834 result
823835 }
824836 }
825837 }
826838 CallingConvention :: Varargs => {
827- let mut holders = Holders :: new ( ) ;
828839 let ( arg_convert, args) = impl_arg_params ( self , cls, false , & mut holders, ctx) ;
829- let call = rust_call ( args, & mut holders) ;
830- let init_holders = holders. init_holders ( ctx) ;
840+ let call = rust_call ( args, holders) ;
831841
832842 quote ! {
833843 unsafe fn #ident<' py>(
@@ -838,7 +848,6 @@ impl<'a> FnSpec<'a> {
838848 ) -> #pyo3_path:: PyResult <* mut #pyo3_path:: ffi:: PyObject > {
839849 let function = #rust_name; // Shadow the function name to avoid #3017
840850 #arg_convert
841- #init_holders
842851 #warnings
843852 let result = #call;
844853 result
0 commit comments