@@ -1180,6 +1180,74 @@ impl<'db> TypeInferenceBuilder<'db> {
11801180 self . infer_annotation_expression ( & type_alias. value , DeferredExpressionState :: Deferred ) ;
11811181 }
11821182
1183+ /// Returns `true` if the current scope is the function body scope of a method of a protocol
1184+ /// (that is, a class which directly inherits `typing.Protocol`.)
1185+ fn in_class_that_inherits_protocol_directly ( & self ) -> bool {
1186+ let current_scope_id = self . scope ( ) . file_scope_id ( self . db ( ) ) ;
1187+ let current_scope = self . index . scope ( current_scope_id) ;
1188+ let Some ( parent_scope_id) = current_scope. parent ( ) else {
1189+ return false ;
1190+ } ;
1191+ let parent_scope = self . index . scope ( parent_scope_id) ;
1192+
1193+ let class_scope = match parent_scope. kind ( ) {
1194+ ScopeKind :: Class => parent_scope,
1195+ ScopeKind :: Annotation => {
1196+ let Some ( class_scope_id) = parent_scope. parent ( ) else {
1197+ return false ;
1198+ } ;
1199+ let potentially_class_scope = self . index . scope ( class_scope_id) ;
1200+
1201+ match potentially_class_scope. kind ( ) {
1202+ ScopeKind :: Class => potentially_class_scope,
1203+ _ => return false ,
1204+ }
1205+ }
1206+ _ => return false ,
1207+ } ;
1208+
1209+ let NodeWithScopeKind :: Class ( node_ref) = class_scope. node ( ) else {
1210+ return false ;
1211+ } ;
1212+
1213+ // TODO move this to `Class` once we add proper `Protocol` support
1214+ node_ref. bases ( ) . iter ( ) . any ( |base| {
1215+ matches ! (
1216+ self . file_expression_type( base) ,
1217+ Type :: KnownInstance ( KnownInstanceType :: Protocol )
1218+ )
1219+ } )
1220+ }
1221+
1222+ /// Returns `true` if the current scope is the function body scope of a function overload (that
1223+ /// is, the stub declaration decorated with `@overload`, not the implementation), or an
1224+ /// abstract method (decorated with `@abstractmethod`.)
1225+ fn in_function_overload_or_abstractmethod ( & self ) -> bool {
1226+ let current_scope_id = self . scope ( ) . file_scope_id ( self . db ( ) ) ;
1227+ let current_scope = self . index . scope ( current_scope_id) ;
1228+
1229+ let function_scope = match current_scope. kind ( ) {
1230+ ScopeKind :: Function => current_scope,
1231+ _ => return false ,
1232+ } ;
1233+
1234+ let NodeWithScopeKind :: Function ( node_ref) = function_scope. node ( ) else {
1235+ return false ;
1236+ } ;
1237+
1238+ node_ref. decorator_list . iter ( ) . any ( |decorator| {
1239+ let decorator_type = self . file_expression_type ( & decorator. expression ) ;
1240+
1241+ match decorator_type {
1242+ Type :: FunctionLiteral ( function) => matches ! (
1243+ function. known( self . db( ) ) ,
1244+ Some ( KnownFunction :: Overload | KnownFunction :: AbstractMethod )
1245+ ) ,
1246+ _ => false ,
1247+ }
1248+ } )
1249+ }
1250+
11831251 fn infer_function_body ( & mut self , function : & ast:: StmtFunctionDef ) {
11841252 // Parameters are odd: they are Definitions in the function body scope, but have no
11851253 // constituent nodes that are part of the function body. In order to get diagnostics
@@ -1210,56 +1278,9 @@ impl<'db> TypeInferenceBuilder<'db> {
12101278 }
12111279 }
12121280
1213- let is_overload_or_abstract = function. decorator_list . iter ( ) . any ( |decorator| {
1214- let decorator_type = self . file_expression_type ( & decorator. expression ) ;
1215-
1216- match decorator_type {
1217- Type :: FunctionLiteral ( function) => matches ! (
1218- function. known( self . db( ) ) ,
1219- Some ( KnownFunction :: Overload | KnownFunction :: AbstractMethod )
1220- ) ,
1221- _ => false ,
1222- }
1223- } ) ;
1224-
1225- let class_inherits_protocol_directly = ( || -> bool {
1226- let current_scope_id = self . scope ( ) . file_scope_id ( self . db ( ) ) ;
1227- let current_scope = self . index . scope ( current_scope_id) ;
1228- let Some ( parent_scope_id) = current_scope. parent ( ) else {
1229- return false ;
1230- } ;
1231- let parent_scope = self . index . scope ( parent_scope_id) ;
1232-
1233- let class_scope = match parent_scope. kind ( ) {
1234- ScopeKind :: Class => parent_scope,
1235- ScopeKind :: Annotation => {
1236- let Some ( class_scope_id) = parent_scope. parent ( ) else {
1237- return false ;
1238- } ;
1239- let potentially_class_scope = self . index . scope ( class_scope_id) ;
1240-
1241- match potentially_class_scope. kind ( ) {
1242- ScopeKind :: Class => potentially_class_scope,
1243- _ => return false ,
1244- }
1245- }
1246- _ => return false ,
1247- } ;
1248-
1249- let NodeWithScopeKind :: Class ( node_ref) = class_scope. node ( ) else {
1250- return false ;
1251- } ;
1252-
1253- // TODO move this to `Class` once we add proper `Protocol` support
1254- node_ref. bases ( ) . iter ( ) . any ( |base| {
1255- matches ! (
1256- self . file_expression_type( base) ,
1257- Type :: KnownInstance ( KnownInstanceType :: Protocol )
1258- )
1259- } )
1260- } ) ( ) ;
1261-
1262- if ( self . in_stub ( ) || is_overload_or_abstract || class_inherits_protocol_directly)
1281+ if ( self . in_stub ( )
1282+ || self . in_function_overload_or_abstractmethod ( )
1283+ || self . in_class_that_inherits_protocol_directly ( ) )
12631284 && self . return_types_and_ranges . is_empty ( )
12641285 && is_stub_suite ( & function. body )
12651286 {
@@ -1552,7 +1573,9 @@ impl<'db> TypeInferenceBuilder<'db> {
15521573 declared_ty : declared_ty. into ( ) ,
15531574 inferred_ty : UnionType :: from_elements ( self . db ( ) , [ declared_ty, default_ty] ) ,
15541575 }
1555- } else if self . in_stub ( )
1576+ } else if ( self . in_stub ( )
1577+ || self . in_function_overload_or_abstractmethod ( )
1578+ || self . in_class_that_inherits_protocol_directly ( ) )
15561579 && default
15571580 . as_ref ( )
15581581 . is_some_and ( |d| d. is_ellipsis_literal_expr ( ) )
0 commit comments