Skip to content

Commit 3f8f533

Browse files
authored
support borrowed values in async functions (#5725)
* support borrowed values in `async` functions * newsfragment
1 parent 185e6b3 commit 3f8f533

File tree

3 files changed

+187
-94
lines changed

3 files changed

+187
-94
lines changed

newsfragments/5725.fixed.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix compile error when using references to `#[pyclass]` types (e.g. `&MyClass`) as arguments to async `#[pyfunction]`s.

pyo3-macros-backend/src/method.rs

Lines changed: 103 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use std::ffi::CString;
33
use std::fmt::Display;
44

55
use proc_macro2::{Span, TokenStream};
6-
use quote::{format_ident, quote, quote_spanned, ToTokens};
6+
use quote::{quote, quote_spanned, ToTokens};
77
use syn::LitCStr;
88
use syn::{ext::IdentExt, spanned::Spanned, Ident, Result};
99

@@ -666,99 +666,116 @@ impl<'a> FnSpec<'a> {
666666
}
667667
}
668668

669-
let rust_call = |args: Vec<TokenStream>, holders: &mut Holders| {
670-
let mut self_arg = || self.tp.self_arg(cls, ExtractErrorMode::Raise, holders, ctx);
669+
let rust_call = |args: Vec<TokenStream>, mut holders: Holders| {
670+
let self_arg = self
671+
.tp
672+
.self_arg(cls, ExtractErrorMode::Raise, &mut holders, ctx);
673+
let init_holders = holders.init_holders(ctx);
671674

672-
let call = if self.asyncness.is_some() {
673-
let throw_callback = if cancel_handle.is_some() {
674-
quote! { Some(__throw_callback) }
675+
// We must assign the output_span to the return value of the call,
676+
// but *not* of the call itself otherwise the spans get really weird
677+
let ret_ident = Ident::new("ret", *output_span);
678+
679+
if self.asyncness.is_some() {
680+
// For async functions, we need to build up a coroutine object to return from the initial function call.
681+
//
682+
// Extraction of the call signature (positional & keyword arguments) happens as part of the initial function
683+
// call. The Python objects are then moved into the Rust future that will be executed when the coroutine is
684+
// awaited.
685+
//
686+
// The argument extraction from Python objects to Rust values then happens inside the future, this allows
687+
// things like extraction to `&MyClass` which needs a holder (for the class guard) to work properly inside
688+
// async code.
689+
//
690+
// It *might* be possible in the future to do the extraction before the coroutine is created, but that would require
691+
// changing argument extraction code to first create holders and then read the values from them later.
692+
let (throw_callback, init_throw_callback) = if cancel_handle.is_some() {
693+
(
694+
quote! { Some(__throw_callback) },
695+
Some(
696+
quote! { let __cancel_handle = #pyo3_path::coroutine::CancelHandle::new();
697+
let __throw_callback = __cancel_handle.throw_callback(); },
698+
),
699+
)
675700
} else {
676-
quote! { None }
701+
(quote! { None }, None)
677702
};
678703
let python_name = &self.python_name;
679704
let qualname_prefix = match cls {
680705
Some(cls) => quote!(Some(<#cls as #pyo3_path::PyClass>::NAME)),
681706
None => quote!(None),
682707
};
683-
let future = match self.tp {
684-
// If extracting `self`, we move the `_slf` pointer into the async block. This reduces the lifetime for which the Rust state is considered "borrowed"
685-
// to just when the async block is executing.
686-
//
687-
// TODO: we should do this with all arguments, not just `self`, e.g. https://github.com/PyO3/pyo3/issues/5681
688-
FnType::Fn(SelfType::Receiver { mutable, .. }) => {
689-
let arg_names = (0..args.len())
690-
.map(|i| format_ident!("arg_{}", i))
691-
.collect::<Vec<_>>();
692-
let method = syn::Ident::new(
693-
if mutable {
694-
"extract_pyclass_ref_mut"
695-
} else {
696-
"extract_pyclass_ref"
697-
},
698-
Span::call_site(),
699-
);
700-
quote! {{
701-
let _slf = unsafe { #pyo3_path::impl_::extract_argument::cast_function_argument(py, _slf) }.to_owned().unbind();
702-
#(let #arg_names = #args;)*
703-
async move {
704-
// SAFETY: attached when future is polled (see `Coroutine::poll`)
705-
let assume_attached = unsafe { #pyo3_path::impl_::coroutine::AssumeAttachedInCoroutine::new() };
706-
let py = assume_attached.py();
707-
let mut holder = None;
708-
let future = function(
709-
#pyo3_path::impl_::extract_argument::#method(_slf.bind_borrowed(py), &mut holder)?,
710-
#(#arg_names),*
711-
);
712-
drop(py);
713-
let result = future.await;
714-
let result: #pyo3_path::PyResult<_> = #pyo3_path::impl_::wrap::converter(&result).wrap(result).map_err(::std::convert::Into::into);
715-
result
716-
}
717-
}}
718-
}
719-
_ => {
720-
let args = self_arg().into_iter().chain(args);
721-
quote! { function(#(#args),*) }
722-
}
708+
// copy self arg into async block
709+
// slf_py will create the owned value to store in the future
710+
// slf_ptr recreates the raw pointer temporarily when building the future
711+
let (slf_py, slf_ptr) = if self_arg.is_some() {
712+
(
713+
Some(
714+
quote! { let _slf = #pyo3_path::Borrowed::from_ptr(py, _slf).to_owned().unbind(); },
715+
),
716+
Some(quote! { let _slf = _slf.as_ptr(); }),
717+
)
718+
} else {
719+
(None, None)
723720
};
724-
let mut call = quote! {{
725-
let future = #future;
726-
#pyo3_path::impl_::coroutine::new_coroutine(
727-
#pyo3_path::intern!(py, stringify!(#python_name)),
728-
#qualname_prefix,
729-
#throw_callback,
730-
async move {
731-
// SAFETY: attached when future is polled (see `Coroutine::poll`)
732-
let assume_attached = unsafe { #pyo3_path::impl_::coroutine::AssumeAttachedInCoroutine::new() };
733-
let output = future.await;
734-
let res = #pyo3_path::impl_::wrap::converter(&output).wrap(output).map_err(::std::convert::Into::into);
735-
#pyo3_path::impl_::wrap::converter(&res).map_into_pyobject(assume_attached.py(), res)
736-
},
721+
// copy extracted arguments into async block
722+
// output_py will create the owned arguments to store in the future
723+
// output_args recreates the borrowed objects temporarily when building the future
724+
let (output_py, output_args) = if !matches!(convention, CallingConvention::Noargs) {
725+
(
726+
Some(quote! {
727+
let output = output.map(|o| o.map(Py::from));
728+
}),
729+
Some(quote! {
730+
let output = output.each_ref().map(|o| o.as_ref().map(|obj| obj.bind_borrowed(assume_attached.py())));
731+
}),
737732
)
738-
}};
739-
if cancel_handle.is_some() {
740-
call = quote! {{
741-
let __cancel_handle = #pyo3_path::coroutine::CancelHandle::new();
742-
let __throw_callback = __cancel_handle.throw_callback();
743-
#call
744-
}};
733+
} else {
734+
(None, None)
735+
};
736+
let args = self_arg.into_iter().chain(args);
737+
let ok_wrap = quotes::ok_wrap(ret_ident.to_token_stream(), ctx);
738+
quote! {
739+
{
740+
let coroutine = {
741+
#slf_py
742+
#output_py
743+
#init_throw_callback
744+
#pyo3_path::impl_::coroutine::new_coroutine(
745+
#pyo3_path::intern!(py, stringify!(#python_name)),
746+
#qualname_prefix,
747+
#throw_callback,
748+
async move {
749+
// SAFETY: attached when future is polled (see `Coroutine::poll`)
750+
let assume_attached = unsafe { #pyo3_path::impl_::coroutine::AssumeAttachedInCoroutine::new() };
751+
#init_holders
752+
let future = {
753+
let py = assume_attached.py();
754+
#slf_ptr
755+
#output_args
756+
function(#(#args),*)
757+
};
758+
let #ret_ident = future.await;
759+
let #ret_ident = #ok_wrap;
760+
#pyo3_path::impl_::wrap::converter(&#ret_ident).map_into_pyobject(assume_attached.py(), #ret_ident)
761+
},
762+
)
763+
};
764+
#pyo3_path::Py::new(py, coroutine).map(#pyo3_path::Py::into_ptr)
765+
}
745766
}
746-
call
747767
} else {
748-
let args = self_arg().into_iter().chain(args);
749-
quote! { function(#(#args),*) }
750-
};
751-
752-
// We must assign the output_span to the return value of the call,
753-
// but *not* of the call itself otherwise the spans get really weird
754-
let ret_ident = Ident::new("ret", *output_span);
755-
let ret_expr = quote! { let #ret_ident = #call; };
756-
let return_conversion =
757-
quotes::map_result_into_ptr(quotes::ok_wrap(ret_ident.to_token_stream(), ctx), ctx);
758-
quote! {
759-
{
760-
#ret_expr
761-
#return_conversion
768+
let args = self_arg.into_iter().chain(args);
769+
let return_conversion = quotes::map_result_into_ptr(
770+
quotes::ok_wrap(ret_ident.to_token_stream(), ctx),
771+
ctx,
772+
);
773+
quote! {
774+
{
775+
#init_holders
776+
let #ret_ident = function(#(#args),*);
777+
#return_conversion
778+
}
762779
}
763780
}
764781
};
@@ -771,10 +788,10 @@ impl<'a> FnSpec<'a> {
771788
};
772789

773790
let warnings = self.warnings.build_py_warning(ctx);
791+
let mut holders = Holders::new();
774792

775793
Ok(match convention {
776794
CallingConvention::Noargs => {
777-
let mut holders = Holders::new();
778795
let args = self
779796
.signature
780797
.arguments
@@ -785,26 +802,22 @@ impl<'a> FnSpec<'a> {
785802
_ => unreachable!("`CallingConvention::Noargs` should not contain any arguments (reaching Python) except for `self`, which is handled below."),
786803
})
787804
.collect();
788-
let call = rust_call(args, &mut holders);
789-
let init_holders = holders.init_holders(ctx);
805+
let call = rust_call(args, holders);
790806
quote! {
791807
unsafe fn #ident<'py>(
792808
py: #pyo3_path::Python<'py>,
793809
_slf: *mut #pyo3_path::ffi::PyObject,
794810
) -> #pyo3_path::PyResult<*mut #pyo3_path::ffi::PyObject> {
795811
let function = #rust_name; // Shadow the function name to avoid #3017
796-
#init_holders
797812
#warnings
798813
let result = #call;
799814
result
800815
}
801816
}
802817
}
803818
CallingConvention::Fastcall => {
804-
let mut holders = Holders::new();
805819
let (arg_convert, args) = impl_arg_params(self, cls, true, &mut holders, ctx);
806-
let call = rust_call(args, &mut holders);
807-
let init_holders = holders.init_holders(ctx);
820+
let call = rust_call(args, holders);
808821

809822
quote! {
810823
unsafe fn #ident<'py>(
@@ -816,18 +829,15 @@ impl<'a> FnSpec<'a> {
816829
) -> #pyo3_path::PyResult<*mut #pyo3_path::ffi::PyObject> {
817830
let function = #rust_name; // Shadow the function name to avoid #3017
818831
#arg_convert
819-
#init_holders
820832
#warnings
821833
let result = #call;
822834
result
823835
}
824836
}
825837
}
826838
CallingConvention::Varargs => {
827-
let mut holders = Holders::new();
828839
let (arg_convert, args) = impl_arg_params(self, cls, false, &mut holders, ctx);
829-
let call = rust_call(args, &mut holders);
830-
let init_holders = holders.init_holders(ctx);
840+
let call = rust_call(args, holders);
831841

832842
quote! {
833843
unsafe fn #ident<'py>(
@@ -838,7 +848,6 @@ impl<'a> FnSpec<'a> {
838848
) -> #pyo3_path::PyResult<*mut #pyo3_path::ffi::PyObject> {
839849
let function = #rust_name; // Shadow the function name to avoid #3017
840850
#arg_convert
841-
#init_holders
842851
#warnings
843852
let result = #call;
844853
result

tests/test_coroutine.rs

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,3 +369,86 @@ fn test_async_method_receiver_with_other_args() {
369369
py_run!(py, *locals, test);
370370
});
371371
}
372+
373+
#[test]
374+
fn test_async_fn_borrowed_values() {
375+
#[pyclass]
376+
struct Data {
377+
value: String,
378+
}
379+
#[pymethods]
380+
impl Data {
381+
#[new]
382+
fn new(value: String) -> Self {
383+
Self { value }
384+
}
385+
async fn borrow_value(&self) -> &str {
386+
&self.value
387+
}
388+
async fn borrow_value_or_default<'a>(&'a self, default: &'a str) -> &'a str {
389+
if self.value.is_empty() {
390+
default
391+
} else {
392+
&self.value
393+
}
394+
}
395+
}
396+
Python::attach(|py| {
397+
let test = r#"
398+
import asyncio
399+
400+
v = Data('hello')
401+
assert asyncio.run(v.borrow_value()) == 'hello'
402+
assert asyncio.run(v.borrow_value_or_default('')) == 'hello'
403+
404+
v_empty = Data('')
405+
assert asyncio.run(v_empty.borrow_value_or_default('default')) == 'default'
406+
"#;
407+
let locals = [("Data", py.get_type::<Data>())].into_py_dict(py).unwrap();
408+
py_run!(py, *locals, test);
409+
});
410+
}
411+
412+
#[test]
413+
fn test_async_fn_class_values() {
414+
#[pyclass]
415+
struct Value(i32);
416+
417+
#[pymethods]
418+
impl Value {
419+
#[new]
420+
fn new(x: i32) -> Self {
421+
Self(x)
422+
}
423+
424+
#[getter]
425+
fn value(&self) -> i32 {
426+
self.0
427+
}
428+
}
429+
430+
#[pyfunction]
431+
async fn add_two_values(obj: &Value, obj2: &Value) -> Value {
432+
Value(obj.0 + obj2.0)
433+
}
434+
435+
Python::attach(|py| {
436+
let test = r#"
437+
import asyncio
438+
439+
v1 = Value(1)
440+
v2 = Value(2)
441+
assert asyncio.run(add_two_values(v1, v2)).value == 3
442+
"#;
443+
let locals = [
444+
("Value", py.get_type::<Value>().into_any()),
445+
(
446+
"add_two_values",
447+
wrap_pyfunction!(add_two_values, py).unwrap().into_any(),
448+
),
449+
]
450+
.into_py_dict(py)
451+
.unwrap();
452+
py_run!(py, *locals, test);
453+
});
454+
}

0 commit comments

Comments
 (0)