Skip to content

Commit 016f795

Browse files
davidhewittTpt
andauthored
Simplify method receivers in macros (#5684)
* simplify `self_arg()` handling * use `PyClassGuard` type instead of custom coroutine guard * simplify logic, adjust test * add `AssumeAttachedInCoroutine` helper * newsfragment * Update pyo3-macros-backend/src/method.rs Co-authored-by: Thomas Tanon <[email protected]> * fixup suggestion conflict --------- Co-authored-by: Thomas Tanon <[email protected]>
1 parent 7c8f8aa commit 016f795

File tree

5 files changed

+94
-153
lines changed

5 files changed

+94
-153
lines changed

newsfragments/5684.changed.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
`async` pymethods now borrow `self` only for the duration of awaiting the future, not the entire method call.

pyo3-macros-backend/src/method.rs

Lines changed: 47 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -262,16 +262,12 @@ impl FnType {
262262
) -> Option<TokenStream> {
263263
let Ctx { pyo3_path, .. } = ctx;
264264
match self {
265-
FnType::Getter(st) | FnType::Setter(st) | FnType::Fn(st) => {
266-
let mut receiver = st.receiver(
267-
cls.expect("no class given for Fn with a \"self\" receiver"),
268-
error_mode,
269-
holders,
270-
ctx,
271-
);
272-
syn::Token![,](Span::call_site()).to_tokens(&mut receiver);
273-
Some(receiver)
274-
}
265+
FnType::Getter(st) | FnType::Setter(st) | FnType::Fn(st) => Some(st.receiver(
266+
cls.expect("no class given for Fn with a \"self\" receiver"),
267+
error_mode,
268+
holders,
269+
ctx,
270+
)),
275271
FnType::FnClass(span) => {
276272
let py = syn::Ident::new("py", Span::call_site());
277273
let slf: Ident = syn::Ident::new("_slf", Span::call_site());
@@ -283,7 +279,7 @@ impl FnType {
283279
.cast_unchecked::<#pyo3_path::types::PyType>()
284280
)
285281
};
286-
Some(quote! { unsafe { #ret }, })
282+
Some(quote! { unsafe { #ret } })
287283
}
288284
FnType::FnModule(span) => {
289285
let py = syn::Ident::new("py", Span::call_site());
@@ -296,7 +292,7 @@ impl FnType {
296292
.cast_unchecked::<#pyo3_path::types::PyModule>()
297293
)
298294
};
299-
Some(quote! { unsafe { #ret }, })
295+
Some(quote! { unsafe { #ret } })
300296
}
301297
FnType::FnStatic | FnType::ClassAttribute => None,
302298
}
@@ -664,36 +660,45 @@ impl<'a> FnSpec<'a> {
664660
Some(cls) => quote!(Some(<#cls as #pyo3_path::PyClass>::NAME)),
665661
None => quote!(None),
666662
};
667-
let arg_names = (0..args.len())
668-
.map(|i| format_ident!("arg_{}", i))
669-
.collect::<Vec<_>>();
670663
let future = match self.tp {
671-
FnType::Fn(SelfType::Receiver { mutable: false, .. }) => {
664+
// If extracting `self`, we move the `_slf` pointer into the async block. This reduces the lifetime for which the Rust state is considered "borrowed"
665+
// to just when the async block is executing.
666+
//
667+
// TODO: we should do this with all arguments, not just `self`, e.g. https://github.com/PyO3/pyo3/issues/5681
668+
FnType::Fn(SelfType::Receiver { mutable, .. }) => {
669+
let arg_names = (0..args.len())
670+
.map(|i| format_ident!("arg_{}", i))
671+
.collect::<Vec<_>>();
672+
let method = syn::Ident::new(
673+
if mutable {
674+
"extract_pyclass_ref_mut"
675+
} else {
676+
"extract_pyclass_ref"
677+
},
678+
Span::call_site(),
679+
);
672680
quote! {{
681+
let _slf = unsafe { #pyo3_path::impl_::pymethods::BoundRef::ref_from_ptr(py, &_slf) }.to_owned().unbind();
673682
#(let #arg_names = #args;)*
674-
let __guard = unsafe { #pyo3_path::impl_::coroutine::RefGuard::<#cls>::new(&#pyo3_path::impl_::pymethods::BoundRef::ref_from_ptr(py, &_slf))? };
675-
async move { function(&__guard, #(#arg_names),*).await }
676-
}}
677-
}
678-
FnType::Fn(SelfType::Receiver { mutable: true, .. }) => {
679-
quote! {{
680-
#(let #arg_names = #args;)*
681-
let mut __guard = unsafe { #pyo3_path::impl_::coroutine::RefMutGuard::<#cls>::new(&#pyo3_path::impl_::pymethods::BoundRef::ref_from_ptr(py, &_slf))? };
682-
async move { function(&mut __guard, #(#arg_names),*).await }
683+
async move {
684+
// SAFETY: attached when future is polled (see `Coroutine::poll`)
685+
let assume_attached = unsafe { #pyo3_path::impl_::coroutine::AssumeAttachedInCoroutine::new() };
686+
let py = assume_attached.py();
687+
let mut holder = None;
688+
let future = function(
689+
#pyo3_path::impl_::extract_argument::#method(_slf.bind(py), &mut holder)?,
690+
#(#arg_names),*
691+
);
692+
drop(py);
693+
let result = future.await;
694+
let result: #pyo3_path::PyResult<_> = #pyo3_path::impl_::wrap::converter(&result).wrap(result).map_err(::std::convert::Into::into);
695+
result
696+
}
683697
}}
684698
}
685699
_ => {
686-
if let Some(self_arg) = self_arg() {
687-
quote! {
688-
function(
689-
// NB #self_arg includes a comma, so none inserted here
690-
#self_arg
691-
#(#args),*
692-
)
693-
}
694-
} else {
695-
quote! { function(#(#args),*) }
696-
}
700+
let args = self_arg().into_iter().chain(args);
701+
quote! { function(#(#args),*) }
697702
}
698703
};
699704
let mut call = quote! {{
@@ -703,13 +708,11 @@ impl<'a> FnSpec<'a> {
703708
#qualname_prefix,
704709
#throw_callback,
705710
async move {
706-
let fut = future.await;
707-
let res = #pyo3_path::impl_::wrap::converter(&fut).wrap(fut).map_err(::std::convert::Into::into);
708-
#pyo3_path::impl_::wrap::converter(&res).map_into_pyobject(
709-
// SAFETY: attached when future is polled (see `Coroutine::poll`)
710-
unsafe { #pyo3_path::Python::assume_attached() },
711-
res
712-
)
711+
// SAFETY: attached when future is polled (see `Coroutine::poll`)
712+
let assume_attached = unsafe { #pyo3_path::impl_::coroutine::AssumeAttachedInCoroutine::new() };
713+
let output = future.await;
714+
let res = #pyo3_path::impl_::wrap::converter(&output).wrap(output).map_err(::std::convert::Into::into);
715+
#pyo3_path::impl_::wrap::converter(&res).map_into_pyobject(assume_attached.py(), res)
713716
},
714717
)
715718
}};
@@ -721,15 +724,8 @@ impl<'a> FnSpec<'a> {
721724
}};
722725
}
723726
call
724-
} else if let Some(self_arg) = self_arg() {
725-
quote! {
726-
function(
727-
// NB #self_arg includes a comma, so none inserted here
728-
#self_arg
729-
#(#args),*
730-
)
731-
}
732727
} else {
728+
let args = self_arg().into_iter().chain(args);
733729
quote! { function(#(#args),*) }
734730
};
735731

pyo3-macros-backend/src/pymethod.rs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1361,7 +1361,8 @@ fn generate_method_body(
13611361
quote! { *mut #pyo3_path::ffi::PyObject },
13621362
];
13631363
let (arg_convert, args) = impl_arg_params(spec, Some(cls), false, holders, ctx);
1364-
let call = quote_spanned! {*output_span=> #cls::#rust_name(#self_arg #(#args),*) };
1364+
let args = self_arg.into_iter().chain(args);
1365+
let call = quote_spanned! {*output_span=> #cls::#rust_name(#(#args),*) };
13651366

13661367
// Use just the text_signature_call_signature() because the class' Python name
13671368
// isn't known to `#[pymethods]` - that has to be attached at runtime from the PyClassImpl
@@ -1399,8 +1400,9 @@ fn generate_method_body(
13991400
quote! { *mut #pyo3_path::ffi::PyObject },
14001401
];
14011402
let (arg_convert, args) = impl_arg_params(spec, Some(cls), false, holders, ctx);
1403+
let args = self_arg.into_iter().chain(args);
14021404
let call = quote! {{
1403-
let r = #cls::#rust_name(#self_arg #(#args),*);
1405+
let r = #cls::#rust_name(#(#args),*);
14041406
#pyo3_path::impl_::wrap::converter(&r)
14051407
.wrap(r)
14061408
.map_err(::core::convert::Into::<#pyo3_path::PyErr>::into)?
@@ -1425,7 +1427,8 @@ fn generate_method_body(
14251427
.collect();
14261428

14271429
let args = extract_proto_arguments(spec, arguments, extract_error_mode, holders, ctx)?;
1428-
let call = quote! { #cls::#rust_name(#self_arg #(#args),*) };
1430+
let args = self_arg.into_iter().chain(args);
1431+
let call = quote! { #cls::#rust_name(#(#args),*) };
14291432
let result = if let Some(return_mode) = return_mode {
14301433
return_mode.return_call_output(call, ctx)
14311434
} else {

src/impl_/coroutine.rs

Lines changed: 13 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,10 @@
1-
use std::{
2-
future::Future,
3-
ops::{Deref, DerefMut},
4-
};
1+
use std::future::Future;
52

63
use crate::{
74
coroutine::{cancel::ThrowCallback, Coroutine},
85
instance::Bound,
9-
pycell::impl_::PyClassBorrowChecker,
10-
pyclass::boolean_struct::False,
116
types::PyString,
12-
Py, PyAny, PyClass, PyErr, PyResult, Python,
7+
Py, PyAny, PyResult, Python,
138
};
149

1510
pub fn new_coroutine<'py, F>(
@@ -19,78 +14,23 @@ pub fn new_coroutine<'py, F>(
1914
future: F,
2015
) -> Coroutine
2116
where
22-
F: Future<Output = Result<Py<PyAny>, PyErr>> + Send + 'static,
17+
F: Future<Output = PyResult<Py<PyAny>>> + Send + 'static,
2318
{
2419
Coroutine::new(Some(name.clone()), qualname_prefix, throw_callback, future)
2520
}
2621

27-
fn get_ptr<T: PyClass>(obj: &Py<T>) -> *mut T {
28-
obj.get_class_object().get_ptr()
29-
}
30-
31-
pub struct RefGuard<T: PyClass>(Py<T>);
32-
33-
impl<T: PyClass> RefGuard<T> {
34-
pub fn new(obj: &Bound<'_, PyAny>) -> PyResult<Self> {
35-
let bound = obj.cast::<T>()?;
36-
bound.get_class_object().borrow_checker().try_borrow()?;
37-
Ok(RefGuard(bound.clone().unbind()))
38-
}
39-
}
40-
41-
impl<T: PyClass> Deref for RefGuard<T> {
42-
type Target = T;
43-
fn deref(&self) -> &Self::Target {
44-
// SAFETY: `RefGuard` has been built from `PyRef` and provides the same guarantees
45-
unsafe { &*get_ptr(&self.0) }
46-
}
47-
}
48-
49-
impl<T: PyClass> Drop for RefGuard<T> {
50-
fn drop(&mut self) {
51-
Python::attach(|py| {
52-
self.0
53-
.bind(py)
54-
.get_class_object()
55-
.borrow_checker()
56-
.release_borrow()
57-
})
58-
}
59-
}
60-
61-
pub struct RefMutGuard<T: PyClass<Frozen = False>>(Py<T>);
62-
63-
impl<T: PyClass<Frozen = False>> RefMutGuard<T> {
64-
pub fn new(obj: &Bound<'_, PyAny>) -> PyResult<Self> {
65-
let bound = obj.cast::<T>()?;
66-
bound.get_class_object().borrow_checker().try_borrow_mut()?;
67-
Ok(RefMutGuard(bound.clone().unbind()))
68-
}
69-
}
22+
/// Handle which assumes that the coroutine is attached to the thread. Unlike `Python<'_>`, this is `Send`.
23+
pub struct AssumeAttachedInCoroutine(());
7024

71-
impl<T: PyClass<Frozen = False>> Deref for RefMutGuard<T> {
72-
type Target = T;
73-
fn deref(&self) -> &Self::Target {
74-
// SAFETY: `RefMutGuard` has been built from `PyRefMut` and provides the same guarantees
75-
unsafe { &*get_ptr(&self.0) }
25+
impl AssumeAttachedInCoroutine {
26+
/// Safety: this should only be used inside a future passed to `new_coroutine`, where the coroutine is
27+
/// guaranteed to be attached to the thread when polled.
28+
pub unsafe fn new() -> Self {
29+
Self(())
7630
}
77-
}
78-
79-
impl<T: PyClass<Frozen = False>> DerefMut for RefMutGuard<T> {
80-
fn deref_mut(&mut self) -> &mut Self::Target {
81-
// SAFETY: `RefMutGuard` has been built from `PyRefMut` and provides the same guarantees
82-
unsafe { &mut *get_ptr(&self.0) }
83-
}
84-
}
8531

86-
impl<T: PyClass<Frozen = False>> Drop for RefMutGuard<T> {
87-
fn drop(&mut self) {
88-
Python::attach(|py| {
89-
self.0
90-
.bind(py)
91-
.get_class_object()
92-
.borrow_checker()
93-
.release_borrow_mut()
94-
})
32+
pub fn py(&self) -> Python<'_> {
33+
// Safety: this type holds the invariant that the thread is attached
34+
unsafe { Python::assume_attached() }
9535
}
9636
}

tests/test_coroutine.rs

Lines changed: 27 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -278,10 +278,18 @@ fn test_async_method_receiver() {
278278
fn new() -> Self {
279279
Self(0)
280280
}
281-
async fn get(&self) -> usize {
281+
async fn get(&self, resolve: bool) -> usize {
282+
if !resolve {
283+
// hang the future to test borrow checking
284+
std::future::pending().await
285+
}
282286
self.0
283287
}
284-
async fn incr(&mut self) -> usize {
288+
async fn incr(&mut self, resolve: bool) -> usize {
289+
if !resolve {
290+
// hang the future to test borrow checking
291+
std::future::pending().await
292+
}
285293
self.0 += 1;
286294
self.0
287295
}
@@ -300,30 +308,23 @@ fn test_async_method_receiver() {
300308
import asyncio
301309
302310
obj = Counter()
303-
coro1 = obj.get()
304-
coro2 = obj.get()
305-
try:
306-
obj.incr() # borrow checking should fail
307-
except RuntimeError as err:
308-
pass
309-
else:
310-
assert False
311-
assert asyncio.run(coro1) == 0
312-
coro2.close()
313-
coro3 = obj.incr()
314-
try:
315-
obj.incr() # borrow checking should fail
316-
except RuntimeError as err:
317-
pass
318-
else:
319-
assert False
320-
try:
321-
obj.get() # borrow checking should fail
322-
except RuntimeError as err:
323-
pass
324-
else:
325-
assert False
326-
assert asyncio.run(coro3) == 1
311+
312+
assert asyncio.run(obj.get(True)) == 0
313+
assert asyncio.run(obj.incr(True)) == 1
314+
315+
for left in [obj.get, obj.incr]:
316+
for right in [obj.get, obj.incr]:
317+
# first future will not resolve to hold the borrow
318+
coro1 = left(False)
319+
coro2 = right(True)
320+
try:
321+
asyncio.run(asyncio.gather(coro1, coro2))
322+
except RuntimeError as err:
323+
ran = False
324+
else:
325+
ran = True
326+
if left is obj.incr or right is obj.incr:
327+
assert not ran, "mutable method calls should not run concurrently with other method calls"
327328
"#;
328329
let locals = [("Counter", py.get_type::<Counter>())]
329330
.into_py_dict(py)

0 commit comments

Comments
 (0)