Skip to content

Commit e8e5a50

Browse files
authored
clean up handling of default values in function signatures (#5677)
* clean up handling of default values in function signatures * fixups * ensure trailing optional argument required
1 parent 6af9595 commit e8e5a50

File tree

6 files changed

+106
-111
lines changed

6 files changed

+106
-111
lines changed

pyo3-macros-backend/src/introspection.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
use crate::method::{FnArg, RegularArg};
1212
use crate::pyfunction::FunctionSignature;
1313
use crate::type_hint::PythonTypeHint;
14-
use crate::utils::PyO3CratePath;
14+
use crate::utils::{expr_to_python, PyO3CratePath};
1515
use proc_macro2::{Span, TokenStream};
1616
use quote::{format_ident, quote, ToTokens};
1717
use std::borrow::Cow;
@@ -291,10 +291,10 @@ fn argument_introspection_data<'a>(
291291
class_type: Option<&Type>,
292292
) -> AttributedIntrospectionNode<'a> {
293293
let mut params: HashMap<_, _> = [("name", IntrospectionNode::String(name.into()))].into();
294-
if desc.default_value.is_some() {
294+
if let Some(expr) = &desc.default_value {
295295
params.insert(
296296
"default",
297-
IntrospectionNode::String(desc.default_value().into()),
297+
IntrospectionNode::String(expr_to_python(expr).into()),
298298
);
299299
}
300300

pyo3-macros-backend/src/method.rs

Lines changed: 1 addition & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use syn::{ext::IdentExt, spanned::Spanned, Ident, Result};
99

1010
use crate::pyfunction::{PyFunctionWarning, WarningFactory};
1111
use crate::pyversions::is_abi3_before;
12-
use crate::utils::{expr_to_python, Ctx};
12+
use crate::utils::Ctx;
1313
use crate::{
1414
attributes::{FromPyWithAttribute, TextSignatureAttribute, TextSignatureAttributeValue},
1515
params::{impl_arg_params, Holders},
@@ -31,28 +31,6 @@ pub struct RegularArg<'a> {
3131
pub annotation: Option<String>,
3232
}
3333

34-
impl RegularArg<'_> {
35-
pub fn default_value(&self) -> String {
36-
if let Self {
37-
default_value: Some(arg_default),
38-
..
39-
} = self
40-
{
41-
expr_to_python(arg_default)
42-
} else if let RegularArg {
43-
option_wrapped_type: Some(..),
44-
..
45-
} = self
46-
{
47-
// functions without a `#[pyo3(signature = (...))]` option
48-
// will treat trailing `Option<T>` arguments as having a default of `None`
49-
"None".to_string()
50-
} else {
51-
"...".to_string()
52-
}
53-
}
54-
}
55-
5634
/// Pythons *args argument
5735
#[derive(Clone, Debug)]
5836
pub struct VarargsArg<'a> {
@@ -220,14 +198,6 @@ impl<'a> FnArg<'a> {
220198
}
221199
}
222200
}
223-
224-
pub fn default_value(&self) -> String {
225-
if let Self::Regular(args) = self {
226-
args.default_value()
227-
} else {
228-
"...".to_string()
229-
}
230-
}
231201
}
232202

233203
fn handle_argument_error(pat: &syn::Pat) -> syn::Error {

pyo3-macros-backend/src/params.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,16 +90,17 @@ pub fn impl_arg_params(
9090

9191
let positional_parameter_names = &spec.signature.python_signature.positional_parameters;
9292
let positional_only_parameters = &spec.signature.python_signature.positional_only_parameters;
93-
let required_positional_parameters = &spec
93+
let required_positional_parameters = spec
9494
.signature
9595
.python_signature
96-
.required_positional_parameters;
96+
.required_positional_parameters();
9797
let keyword_only_parameters = spec
9898
.signature
9999
.python_signature
100100
.keyword_only_parameters
101101
.iter()
102-
.map(|(name, required)| {
102+
.map(|(name, default_value)| {
103+
let required = default_value.is_none();
103104
quote! {
104105
#pyo3_path::impl_::extract_argument::KeywordOnlyParameterDescription {
105106
name: #name,

pyo3-macros-backend/src/pyfunction/signature.rs

Lines changed: 37 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use crate::{
22
attributes::{kw, KeywordAttribute},
3-
method::{FnArg, RegularArg},
3+
method::FnArg,
4+
utils::expr_to_python,
45
};
56
use proc_macro2::{Span, TokenStream};
67
use quote::ToTokens;
@@ -9,7 +10,7 @@ use syn::{
910
parse::{Parse, ParseStream},
1011
punctuated::Punctuated,
1112
spanned::Spanned,
12-
Token,
13+
Expr, Token,
1314
};
1415

1516
#[derive(Clone)]
@@ -270,10 +271,11 @@ impl ConstructorAttribute {
270271
pub struct PythonSignature {
271272
pub positional_parameters: Vec<String>,
272273
pub positional_only_parameters: usize,
273-
pub required_positional_parameters: usize,
274+
/// Vector of expressions representing positional defaults
275+
pub default_positional_parameters: Vec<Expr>,
274276
pub varargs: Option<String>,
275-
// Tuples of keyword name and whether it is required
276-
pub keyword_only_parameters: Vec<(String, bool)>,
277+
// Tuples of keyword name and optional default value
278+
pub keyword_only_parameters: Vec<(String, Option<Expr>)>,
277279
pub kwargs: Option<String>,
278280
}
279281

@@ -284,6 +286,13 @@ impl PythonSignature {
284286
&& self.varargs.is_none()
285287
&& self.kwargs.is_none()
286288
}
289+
290+
pub fn required_positional_parameters(&self) -> usize {
291+
self.positional_parameters
292+
.len()
293+
.checked_sub(self.default_positional_parameters.len())
294+
.expect("should always have positional defaults <= positional parameters")
295+
}
287296
}
288297

289298
#[derive(Clone)]
@@ -309,23 +318,24 @@ impl ParseState {
309318
&mut self,
310319
signature: &mut PythonSignature,
311320
name: String,
312-
required: bool,
321+
default_value: Option<Expr>,
313322
span: Span,
314323
) -> syn::Result<()> {
315324
match self {
316325
ParseState::Positional | ParseState::PositionalAfterPosargs => {
317326
signature.positional_parameters.push(name);
318-
if required {
319-
signature.required_positional_parameters += 1;
320-
ensure_spanned!(
321-
signature.required_positional_parameters == signature.positional_parameters.len(),
322-
span => "cannot have required positional parameter after an optional parameter"
323-
);
327+
if let Some(default_value) = default_value {
328+
signature.default_positional_parameters.push(default_value);
329+
// Now all subsequent positional parameters must also have defaults
330+
} else if !signature.default_positional_parameters.is_empty() {
331+
bail_spanned!(span => "cannot have required positional parameter after an optional parameter")
324332
}
325333
Ok(())
326334
}
327335
ParseState::Keywords => {
328-
signature.keyword_only_parameters.push((name, required));
336+
signature
337+
.keyword_only_parameters
338+
.push((name, default_value));
329339
Ok(())
330340
}
331341
ParseState::Done => {
@@ -475,7 +485,9 @@ impl<'a> FunctionSignature<'a> {
475485
parse_state.add_argument(
476486
&mut python_signature,
477487
arg.ident.unraw().to_string(),
478-
arg.eq_and_default.is_none(),
488+
arg.eq_and_default
489+
.as_ref()
490+
.map(|(_, default)| default.clone()),
479491
arg.span(),
480492
)?;
481493
let FnArg::Regular(fn_arg) = fn_arg else {
@@ -577,17 +589,6 @@ impl<'a> FunctionSignature<'a> {
577589
continue;
578590
}
579591

580-
if let FnArg::Regular(RegularArg { .. }) = arg {
581-
// This argument is required, all previous arguments must also have been required
582-
assert_eq!(
583-
python_signature.required_positional_parameters,
584-
python_signature.positional_parameters.len(),
585-
);
586-
587-
python_signature.required_positional_parameters =
588-
python_signature.positional_parameters.len() + 1;
589-
}
590-
591592
python_signature
592593
.positional_parameters
593594
.push(arg.name().unraw().to_string());
@@ -600,14 +601,6 @@ impl<'a> FunctionSignature<'a> {
600601
}
601602
}
602603

603-
fn default_value_for_parameter(&self, parameter: &str) -> String {
604-
if let Some(fn_arg) = self.arguments.iter().find(|arg| arg.name() == parameter) {
605-
fn_arg.default_value()
606-
} else {
607-
"...".to_string()
608-
}
609-
}
610-
611604
pub fn text_signature(&self, self_argument: Option<&str>) -> String {
612605
let mut output = String::new();
613606
output.push('(');
@@ -630,14 +623,19 @@ impl<'a> FunctionSignature<'a> {
630623

631624
let py_sig = &self.python_signature;
632625

633-
for (i, parameter) in py_sig.positional_parameters.iter().enumerate() {
626+
let defaults = std::iter::repeat_n(None, py_sig.required_positional_parameters())
627+
.chain(py_sig.default_positional_parameters.iter().map(Some));
628+
629+
for (i, (parameter, default)) in
630+
std::iter::zip(&py_sig.positional_parameters, defaults).enumerate()
631+
{
634632
maybe_push_comma(&mut output);
635633

636634
output.push_str(parameter);
637635

638-
if i >= py_sig.required_positional_parameters {
636+
if let Some(expr) = default {
639637
output.push('=');
640-
output.push_str(&self.default_value_for_parameter(parameter));
638+
output.push_str(&expr_to_python(expr));
641639
}
642640

643641
if py_sig.positional_only_parameters > 0 && i + 1 == py_sig.positional_only_parameters {
@@ -654,12 +652,12 @@ impl<'a> FunctionSignature<'a> {
654652
output.push('*');
655653
}
656654

657-
for (parameter, required) in &py_sig.keyword_only_parameters {
655+
for (parameter, default) in &py_sig.keyword_only_parameters {
658656
maybe_push_comma(&mut output);
659657
output.push_str(parameter);
660-
if !required {
658+
if let Some(expr) = default {
661659
output.push('=');
662-
output.push_str(&self.default_value_for_parameter(parameter));
660+
output.push_str(&expr_to_python(expr));
663661
}
664662
}
665663

tests/test_pyfunction.rs

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,32 @@ fn test_optional_bool() {
4848
});
4949
}
5050

51+
#[test]
52+
fn test_trailing_optional_no_signature() {
53+
// Since PyO3 0.24, trailing optional arguments are treated like any other required argument
54+
// (previously would get an implicit default of `None`)
55+
56+
#[pyfunction]
57+
fn trailing_optional(x: i32, y: Option<i32>) -> String {
58+
format!("x={x:?} y={y:?}")
59+
}
60+
61+
Python::attach(|py| {
62+
let f = wrap_pyfunction!(trailing_optional)(py).unwrap();
63+
64+
py_assert!(py, f, "f(1, 2) == 'x=1 y=Some(2)'");
65+
py_assert!(py, f, "f(2, None) == 'x=2 y=None'");
66+
67+
py_expect_exception!(
68+
py,
69+
f,
70+
"f(3)",
71+
PyTypeError,
72+
"trailing_optional() missing 1 required positional argument: 'y'"
73+
);
74+
});
75+
}
76+
5177
#[pyfunction]
5278
#[pyo3(signature=(arg))]
5379
fn required_optional_str(arg: Option<&str>) -> &str {
@@ -582,7 +608,10 @@ fn test_return_value_borrows_from_arguments() {
582608

583609
#[test]
584610
fn test_some_wrap_arguments() {
585-
// https://github.com/PyO3/pyo3/issues/3460
611+
// Option<T> arguments get special treatment in pyfunction default values where it's
612+
// valid to pass the inner type without wrapping in `Some()`.
613+
//
614+
// See also https://github.com/PyO3/pyo3/issues/3460
586615
const NONE: Option<u8> = None;
587616
#[pyfunction(signature = (a = 1, b = Some(2), c = None, d = NONE))]
588617
fn some_wrap_arguments(

tests/test_text_signature.rs

Lines changed: 31 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -146,48 +146,45 @@ fn test_auto_test_signature_function() {
146146
let _ = (a, b, c);
147147
}
148148

149+
#[pyfunction]
150+
fn trailing_optional_required(a: i32, b: Option<i32>, c: Option<i32>) {
151+
// Since PyO3 0.24, trailing optional arguments are treated like any other required argument
152+
// (previously would get an implicit default of `None`)
153+
let _ = (a, b, c);
154+
}
155+
156+
macro_rules! assert_text_signature {
157+
($py:expr, $func:ident, $expected:expr) => {
158+
assert_eq!(
159+
wrap_pyfunction!($func, $py)
160+
.unwrap()
161+
.getattr("__text_signature__")
162+
.unwrap()
163+
.cast_into::<pyo3::types::PyString>()
164+
.unwrap(),
165+
$expected
166+
);
167+
};
168+
}
169+
149170
Python::attach(|py| {
150-
let f = wrap_pyfunction!(my_function)(py).unwrap();
151-
py_assert!(
152-
py,
153-
f,
154-
"f.__text_signature__ == '(a, b, c)', f.__text_signature__"
155-
);
171+
assert_text_signature!(py, my_function, "(a, b, c)");
156172

157-
let f = wrap_pyfunction!(my_function_2)(py).unwrap();
158-
py_assert!(
159-
py,
160-
f,
161-
"f.__text_signature__ == '($module, a, b, c)', f.__text_signature__"
162-
);
173+
assert_text_signature!(py, my_function_2, "($module, a, b, c)");
163174

164-
let f = wrap_pyfunction!(my_function_3)(py).unwrap();
165-
py_assert!(
166-
py,
167-
f,
168-
"f.__text_signature__ == '(a, /, b=None, *, c=5)', f.__text_signature__"
169-
);
175+
assert_text_signature!(py, my_function_3, "(a, /, b=None, *, c=5)");
170176

171-
let f = wrap_pyfunction!(my_function_4)(py).unwrap();
172-
py_assert!(
173-
py,
174-
f,
175-
"f.__text_signature__ == '(a, /, b=None, *args, c, d=5, **kwargs)', f.__text_signature__"
176-
);
177+
assert_text_signature!(py, my_function_4, "(a, /, b=None, *args, c, d=5, **kwargs)");
177178

178-
let f = wrap_pyfunction!(my_function_5)(py).unwrap();
179-
py_assert!(
179+
assert_text_signature!(
180180
py,
181-
f,
182-
"f.__text_signature__ == '(a=1, /, b=None, c=1.5, d=5, e=\"pyo3\", f=\\'f\\', h=True)', f.__text_signature__"
181+
my_function_5,
182+
"(a=1, /, b=None, c=1.5, d=5, e=\"pyo3\", f='f', h=True)"
183183
);
184184

185-
let f = wrap_pyfunction!(my_function_6)(py).unwrap();
186-
py_assert!(
187-
py,
188-
f,
189-
"f.__text_signature__ == '(a, b=None, c=None)', f.__text_signature__"
190-
);
185+
assert_text_signature!(py, my_function_6, "(a, b=None, c=None)");
186+
187+
assert_text_signature!(py, trailing_optional_required, "(a, b, c)");
191188
});
192189
}
193190

0 commit comments

Comments
 (0)