@@ -202,7 +202,15 @@ def prepare_func_def(
202202 else (FUNC_STATICMETHOD if fdef .is_static else FUNC_NORMAL )
203203 )
204204 sig = mapper .fdef_to_sig (fdef , options .strict_dunders_typing )
205- decl = FuncDecl (fdef .name , class_name , module_name , sig , kind )
205+ decl = FuncDecl (
206+ fdef .name ,
207+ class_name ,
208+ module_name ,
209+ sig ,
210+ kind ,
211+ is_generator = fdef .is_generator ,
212+ is_coroutine = fdef .is_coroutine ,
213+ )
206214 mapper .func_to_decl [fdef ] = decl
207215 return decl
208216
@@ -217,7 +225,7 @@ def create_generator_class_for_func(
217225 """
218226 assert fdef .is_coroutine or fdef .is_generator
219227 name = "_" .join (x for x in [fdef .name , class_name ] if x ) + "_gen" + name_suffix
220- cir = ClassIR (name , module_name , is_generated = True , is_final_class = True )
228+ cir = ClassIR (name , module_name , is_generated = True , is_final_class = class_name is None )
221229 cir .reuse_freed_instance = True
222230 mapper .fdef_to_generator [fdef ] = cir
223231
@@ -816,14 +824,70 @@ def adjust_generator_classes_of_methods(mapper: Mapper) -> None:
816824 This is a separate pass after type map has been built, since we need all classes
817825 to be processed to analyze class hierarchies.
818826 """
819- for fdef , ir in mapper .func_to_decl .items ():
827+
828+ generator_methods = []
829+
830+ for fdef , fn_ir in mapper .func_to_decl .items ():
820831 if isinstance (fdef , FuncDef ) and (fdef .is_coroutine or fdef .is_generator ):
821- gen_ir = create_generator_class_for_func (ir .module_name , ir .class_name , fdef , mapper )
832+ gen_ir = create_generator_class_for_func (
833+ fn_ir .module_name , fn_ir .class_name , fdef , mapper
834+ )
822835 # TODO: We could probably support decorators sometimes (static and class method?)
823836 if not fdef .is_decorated :
824- # Give a more precise type for generators, so that we can optimize
825- # code that uses them. They return a generator object, which has a
826- # specific class. Without this, the type would have to be 'object'.
827- ir .sig .ret_type = RInstance (gen_ir )
828- if ir .bound_sig :
829- ir .bound_sig .ret_type = RInstance (gen_ir )
837+ name = fn_ir .name
838+ precise_ret_type = True
839+ if fn_ir .class_name is not None :
840+ class_ir = mapper .type_to_ir [fdef .info ]
841+ subcls = class_ir .subclasses ()
842+ if subcls is None :
843+ # Override could be of a different type, so we can't make assumptions.
844+ precise_ret_type = False
845+ else :
846+ for s in subcls :
847+ if name in s .method_decls :
848+ m = s .method_decls [name ]
849+ if (
850+ m .is_generator != fn_ir .is_generator
851+ or m .is_coroutine != fn_ir .is_coroutine
852+ ):
853+ # Override is of a different kind, and the optimization
854+ # to use a precise generator return type doesn't work.
855+ precise_ret_type = False
856+ else :
857+ class_ir = None
858+
859+ if precise_ret_type :
860+ # Give a more precise type for generators, so that we can optimize
861+ # code that uses them. They return a generator object, which has a
862+ # specific class. Without this, the type would have to be 'object'.
863+ fn_ir .sig .ret_type = RInstance (gen_ir )
864+ if fn_ir .bound_sig :
865+ fn_ir .bound_sig .ret_type = RInstance (gen_ir )
866+ if class_ir is not None :
867+ if class_ir .is_method_final (name ):
868+ gen_ir .is_final_class = True
869+ generator_methods .append ((name , class_ir , gen_ir ))
870+
871+ new_bases = {}
872+
873+ for name , class_ir , gen in generator_methods :
874+ # For generator methods, we need to have subclass generator classes inherit from
875+ # baseclass generator classes when there are overrides to maintain LSP.
876+ base = class_ir .real_base ()
877+ if base is not None :
878+ if base .has_method (name ):
879+ base_sig = base .method_sig (name )
880+ if isinstance (base_sig .ret_type , RInstance ):
881+ base_gen = base_sig .ret_type .class_ir
882+ new_bases [gen ] = base_gen
883+
884+ # Add generator inheritance relationships by adjusting MROs.
885+ for deriv , base in new_bases .items ():
886+ if base .children is not None :
887+ base .children .append (deriv )
888+ while True :
889+ deriv .mro .append (base )
890+ deriv .base_mro .append (base )
891+ if base not in new_bases :
892+ break
893+ base = new_bases [base ]
0 commit comments