Skip to content

Commit 4a4a376

Browse files
[red-knot] Allow ellipsis default params in stub functions (astral-sh#17243)
## Summary Fixes astral-sh#17234 ## Test Plan Add tests to functions/paremeters.md --------- Co-authored-by: Carl Meyer <[email protected]>
1 parent 5e0f563 commit 4a4a376

File tree

2 files changed

+113
-51
lines changed

2 files changed

+113
-51
lines changed

crates/red_knot_python_semantic/resources/mdtest/function/parameters.md

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,3 +73,42 @@ from typing import Any
7373
def g(x: Any = "foo"):
7474
reveal_type(x) # revealed: Any | Literal["foo"]
7575
```
76+
77+
## Stub functions
78+
79+
### In Protocol
80+
81+
```py
82+
from typing import Protocol
83+
84+
class Foo(Protocol):
85+
def x(self, y: bool = ...): ...
86+
def y[T](self, y: T = ...) -> T: ...
87+
88+
class GenericFoo[T](Protocol):
89+
def x(self, y: bool = ...) -> T: ...
90+
```
91+
92+
### In abstract method
93+
94+
```py
95+
from abc import abstractmethod
96+
97+
class Bar:
98+
@abstractmethod
99+
def x(self, y: bool = ...): ...
100+
@abstractmethod
101+
def y[T](self, y: T = ...) -> T: ...
102+
```
103+
104+
### In function overload
105+
106+
```py
107+
from typing import overload
108+
109+
@overload
110+
def x(y: None = ...) -> None: ...
111+
@overload
112+
def x(y: int) -> str: ...
113+
def x(y: int | None = None) -> str | None: ...
114+
```

crates/red_knot_python_semantic/src/types/infer.rs

Lines changed: 74 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1180,6 +1180,74 @@ impl<'db> TypeInferenceBuilder<'db> {
11801180
self.infer_annotation_expression(&type_alias.value, DeferredExpressionState::Deferred);
11811181
}
11821182

1183+
/// Returns `true` if the current scope is the function body scope of a method of a protocol
1184+
/// (that is, a class which directly inherits `typing.Protocol`.)
1185+
fn in_class_that_inherits_protocol_directly(&self) -> bool {
1186+
let current_scope_id = self.scope().file_scope_id(self.db());
1187+
let current_scope = self.index.scope(current_scope_id);
1188+
let Some(parent_scope_id) = current_scope.parent() else {
1189+
return false;
1190+
};
1191+
let parent_scope = self.index.scope(parent_scope_id);
1192+
1193+
let class_scope = match parent_scope.kind() {
1194+
ScopeKind::Class => parent_scope,
1195+
ScopeKind::Annotation => {
1196+
let Some(class_scope_id) = parent_scope.parent() else {
1197+
return false;
1198+
};
1199+
let potentially_class_scope = self.index.scope(class_scope_id);
1200+
1201+
match potentially_class_scope.kind() {
1202+
ScopeKind::Class => potentially_class_scope,
1203+
_ => return false,
1204+
}
1205+
}
1206+
_ => return false,
1207+
};
1208+
1209+
let NodeWithScopeKind::Class(node_ref) = class_scope.node() else {
1210+
return false;
1211+
};
1212+
1213+
// TODO move this to `Class` once we add proper `Protocol` support
1214+
node_ref.bases().iter().any(|base| {
1215+
matches!(
1216+
self.file_expression_type(base),
1217+
Type::KnownInstance(KnownInstanceType::Protocol)
1218+
)
1219+
})
1220+
}
1221+
1222+
/// Returns `true` if the current scope is the function body scope of a function overload (that
1223+
/// is, the stub declaration decorated with `@overload`, not the implementation), or an
1224+
/// abstract method (decorated with `@abstractmethod`.)
1225+
fn in_function_overload_or_abstractmethod(&self) -> bool {
1226+
let current_scope_id = self.scope().file_scope_id(self.db());
1227+
let current_scope = self.index.scope(current_scope_id);
1228+
1229+
let function_scope = match current_scope.kind() {
1230+
ScopeKind::Function => current_scope,
1231+
_ => return false,
1232+
};
1233+
1234+
let NodeWithScopeKind::Function(node_ref) = function_scope.node() else {
1235+
return false;
1236+
};
1237+
1238+
node_ref.decorator_list.iter().any(|decorator| {
1239+
let decorator_type = self.file_expression_type(&decorator.expression);
1240+
1241+
match decorator_type {
1242+
Type::FunctionLiteral(function) => matches!(
1243+
function.known(self.db()),
1244+
Some(KnownFunction::Overload | KnownFunction::AbstractMethod)
1245+
),
1246+
_ => false,
1247+
}
1248+
})
1249+
}
1250+
11831251
fn infer_function_body(&mut self, function: &ast::StmtFunctionDef) {
11841252
// Parameters are odd: they are Definitions in the function body scope, but have no
11851253
// constituent nodes that are part of the function body. In order to get diagnostics
@@ -1210,56 +1278,9 @@ impl<'db> TypeInferenceBuilder<'db> {
12101278
}
12111279
}
12121280

1213-
let is_overload_or_abstract = function.decorator_list.iter().any(|decorator| {
1214-
let decorator_type = self.file_expression_type(&decorator.expression);
1215-
1216-
match decorator_type {
1217-
Type::FunctionLiteral(function) => matches!(
1218-
function.known(self.db()),
1219-
Some(KnownFunction::Overload | KnownFunction::AbstractMethod)
1220-
),
1221-
_ => false,
1222-
}
1223-
});
1224-
1225-
let class_inherits_protocol_directly = (|| -> bool {
1226-
let current_scope_id = self.scope().file_scope_id(self.db());
1227-
let current_scope = self.index.scope(current_scope_id);
1228-
let Some(parent_scope_id) = current_scope.parent() else {
1229-
return false;
1230-
};
1231-
let parent_scope = self.index.scope(parent_scope_id);
1232-
1233-
let class_scope = match parent_scope.kind() {
1234-
ScopeKind::Class => parent_scope,
1235-
ScopeKind::Annotation => {
1236-
let Some(class_scope_id) = parent_scope.parent() else {
1237-
return false;
1238-
};
1239-
let potentially_class_scope = self.index.scope(class_scope_id);
1240-
1241-
match potentially_class_scope.kind() {
1242-
ScopeKind::Class => potentially_class_scope,
1243-
_ => return false,
1244-
}
1245-
}
1246-
_ => return false,
1247-
};
1248-
1249-
let NodeWithScopeKind::Class(node_ref) = class_scope.node() else {
1250-
return false;
1251-
};
1252-
1253-
// TODO move this to `Class` once we add proper `Protocol` support
1254-
node_ref.bases().iter().any(|base| {
1255-
matches!(
1256-
self.file_expression_type(base),
1257-
Type::KnownInstance(KnownInstanceType::Protocol)
1258-
)
1259-
})
1260-
})();
1261-
1262-
if (self.in_stub() || is_overload_or_abstract || class_inherits_protocol_directly)
1281+
if (self.in_stub()
1282+
|| self.in_function_overload_or_abstractmethod()
1283+
|| self.in_class_that_inherits_protocol_directly())
12631284
&& self.return_types_and_ranges.is_empty()
12641285
&& is_stub_suite(&function.body)
12651286
{
@@ -1552,7 +1573,9 @@ impl<'db> TypeInferenceBuilder<'db> {
15521573
declared_ty: declared_ty.into(),
15531574
inferred_ty: UnionType::from_elements(self.db(), [declared_ty, default_ty]),
15541575
}
1555-
} else if self.in_stub()
1576+
} else if (self.in_stub()
1577+
|| self.in_function_overload_or_abstractmethod()
1578+
|| self.in_class_that_inherits_protocol_directly())
15561579
&& default
15571580
.as_ref()
15581581
.is_some_and(|d| d.is_ellipsis_literal_expr())

0 commit comments

Comments
 (0)