@@ -449,6 +449,43 @@ def _unwrap_descriptor(dobj):
449449 return getattr (obj , '__get__' , obj )
450450
451451
452+ def _unwrap_object (obj : T , * , stop : Optional [Callable [[T ], bool ]] = None ) -> T :
453+ """
454+ This is a modified version of `inspect.unwrap()` that properly handles classes.
455+
456+ Follows the chains of `__wrapped__` attributes, until either:
457+ 1. `obj.__wrapped__` is missing or None
458+ 2. `obj` is a class and `obj.__wrapped__` has a different name or module
459+ 3. `stop` is given and `stop(obj)` is True
460+ """
461+
462+ orig = obj # remember the original func for error reporting
463+ # Memoise by id to tolerate non-hashable objects, but store objects to
464+ # ensure they aren't destroyed, which would allow their IDs to be reused.
465+ memo = {id (orig ): orig }
466+ recursion_limit = sys .getrecursionlimit ()
467+ while hasattr (obj , '__wrapped__' ):
468+ if stop is not None and stop (obj ):
469+ break
470+
471+ candidate = obj .__wrapped__
472+ if candidate is None :
473+ break
474+
475+ if isinstance (candidate , type ) and isinstance (orig , type ):
476+ if not (candidate .__name__ == orig .__name__
477+ and candidate .__module__ == orig .__module__ ):
478+ break
479+
480+ obj = candidate
481+ id_func = id (obj )
482+ if (id_func in memo ) or (len (memo ) >= recursion_limit ):
483+ raise ValueError ('wrapper loop when unwrapping {!r}' .format (orig ))
484+ memo [id_func ] = obj
485+
486+ return obj
487+
488+
452489def _filter_type (type : Type [T ],
453490 values : Union [Iterable ['Doc' ], Mapping [str , 'Doc' ]]) -> List [T ]:
454491 """
@@ -712,11 +749,11 @@ def __init__(self, module: Union[ModuleType, str], *,
712749 "exported in `__all__`" )
713750 else :
714751 if not _is_blacklisted (name , self ):
715- obj = inspect . unwrap (obj )
752+ obj = _unwrap_object (obj )
716753 public_objs .append ((name , obj ))
717754 else :
718755 def is_from_this_module (obj ):
719- mod = inspect .getmodule (inspect . unwrap (obj ))
756+ mod = inspect .getmodule (_unwrap_object (obj ))
720757 return mod is None or mod .__name__ == self .obj .__name__
721758
722759 for name , obj in inspect .getmembers (self .obj ):
@@ -730,7 +767,7 @@ def is_from_this_module(obj):
730767 self ._context .blacklisted .add (f'{ self .refname } .{ name } ' )
731768 continue
732769
733- obj = inspect . unwrap (obj )
770+ obj = _unwrap_object (obj )
734771 public_objs .append ((name , obj ))
735772
736773 index = list (self .obj .__dict__ ).index
@@ -1066,7 +1103,7 @@ def __init__(self, name: str, module: Module, obj, *, docstring: Optional[str] =
10661103 self .module ._context .blacklisted .add (f'{ self .refname } .{ _name } ' )
10671104 continue
10681105
1069- obj = inspect . unwrap (obj )
1106+ obj = _unwrap_object (obj )
10701107 public_objs .append ((_name , obj ))
10711108
10721109 def definition_order_index (
@@ -1428,7 +1465,7 @@ def _is_async(self):
14281465 try :
14291466 # Both of these are required because coroutines aren't classified as async
14301467 # generators and vice versa.
1431- obj = inspect . unwrap (self .obj )
1468+ obj = _unwrap_object (self .obj )
14321469 return (inspect .iscoroutinefunction (obj ) or
14331470 inspect .isasyncgenfunction (obj ))
14341471 except AttributeError :
0 commit comments