@@ -449,6 +449,39 @@ def _unwrap_descriptor(dobj):
449449 return getattr (obj , '__get__' , obj )
450450
451451
452+ def _unwrap_object (obj : T ) -> T :
453+ """
454+ This is a modified version of `inspect.unwrap()` that properly handles classes.
455+
456+ Follows the chains of `__wrapped__` attributes, until either:
457+ 1. `obj.__wrapped__` is missing or None
458+ 2. `obj` is a class and `obj.__wrapped__` has a different name or module
459+ """
460+
461+ orig = obj # remember the original func for error reporting
462+ # Memoise by id to tolerate non-hashable objects, but store objects to
463+ # ensure they aren't destroyed, which would allow their IDs to be reused.
464+ memo = {id (orig ): orig }
465+ recursion_limit = sys .getrecursionlimit ()
466+ while hasattr (obj , '__wrapped__' ):
467+ candidate = obj .__wrapped__
468+ if candidate is None :
469+ break
470+
471+ if isinstance (candidate , type ) and isinstance (orig , type ):
472+ if not (candidate .__name__ == orig .__name__
473+ and candidate .__module__ == orig .__module__ ):
474+ break
475+
476+ obj = typing .cast (T , candidate )
477+ id_func = id (obj )
478+ if (id_func in memo ) or (len (memo ) >= recursion_limit ):
479+ raise ValueError ('wrapper loop when unwrapping {!r}' .format (orig ))
480+ memo [id_func ] = obj
481+
482+ return obj
483+
484+
452485def _filter_type (type : Type [T ],
453486 values : Union [Iterable ['Doc' ], Mapping [str , 'Doc' ]]) -> List [T ]:
454487 """
@@ -712,11 +745,11 @@ def __init__(self, module: Union[ModuleType, str], *,
712745 "exported in `__all__`" )
713746 else :
714747 if not _is_blacklisted (name , self ):
715- obj = inspect . unwrap (obj )
748+ obj = _unwrap_object (obj )
716749 public_objs .append ((name , obj ))
717750 else :
718751 def is_from_this_module (obj ):
719- mod = inspect .getmodule (inspect . unwrap (obj ))
752+ mod = inspect .getmodule (_unwrap_object (obj ))
720753 return mod is None or mod .__name__ == self .obj .__name__
721754
722755 for name , obj in inspect .getmembers (self .obj ):
@@ -730,7 +763,7 @@ def is_from_this_module(obj):
730763 self ._context .blacklisted .add (f'{ self .refname } .{ name } ' )
731764 continue
732765
733- obj = inspect . unwrap (obj )
766+ obj = _unwrap_object (obj )
734767 public_objs .append ((name , obj ))
735768
736769 index = list (self .obj .__dict__ ).index
@@ -1066,7 +1099,7 @@ def __init__(self, name: str, module: Module, obj, *, docstring: Optional[str] =
10661099 self .module ._context .blacklisted .add (f'{ self .refname } .{ _name } ' )
10671100 continue
10681101
1069- obj = inspect . unwrap (obj )
1102+ obj = _unwrap_object (obj )
10701103 public_objs .append ((_name , obj ))
10711104
10721105 def definition_order_index (
@@ -1428,7 +1461,7 @@ def _is_async(self):
14281461 try :
14291462 # Both of these are required because coroutines aren't classified as async
14301463 # generators and vice versa.
1431- obj = inspect . unwrap (self .obj )
1464+ obj = _unwrap_object (self .obj )
14321465 return (inspect .iscoroutinefunction (obj ) or
14331466 inspect .isasyncgenfunction (obj ))
14341467 except AttributeError :
0 commit comments