|
| 1 | +"""Mock specification.""" |
| 2 | +import inspect |
| 3 | +import functools |
| 4 | +import warnings |
| 5 | +from typing import Any, Dict, NamedTuple, Optional, Tuple, Type, Union, get_type_hints |
| 6 | + |
| 7 | +from .warnings import IncorrectCallWarning |
| 8 | + |
| 9 | + |
| 10 | +class BoundArgs(NamedTuple): |
| 11 | + """Arguments bound to a spec.""" |
| 12 | + |
| 13 | + args: Tuple[Any, ...] |
| 14 | + kwargs: Dict[str, Any] |
| 15 | + |
| 16 | + |
| 17 | +class Spec: |
| 18 | + """Interface defining a Spy's specification. |
| 19 | +
|
| 20 | + Arguments: |
| 21 | + source: The source object for the specification. |
| 22 | + name: The spec's name. If left unspecified, will be derived from |
| 23 | + `source`, if possible. Will fallback to a default value. |
| 24 | + module_name: The spec's module name. If left unspecified or `True`, |
| 25 | + will be derived from `source`, if possible. If explicitly set to `None` |
| 26 | + or `False` or it is unable to be derived, a module name will not be used. |
| 27 | +
|
| 28 | + """ |
| 29 | + |
| 30 | + _DEFAULT_SPY_NAME = "unnamed" |
| 31 | + |
| 32 | + def __init__( |
| 33 | + self, |
| 34 | + source: Optional[Any], |
| 35 | + name: Optional[str], |
| 36 | + module_name: Union[str, bool, None] = True, |
| 37 | + ) -> None: |
| 38 | + self._source = source |
| 39 | + |
| 40 | + if name is not None: |
| 41 | + self._name = name |
| 42 | + elif source is not None: |
| 43 | + self._name = getattr(source, "__name__", self._DEFAULT_SPY_NAME) |
| 44 | + else: |
| 45 | + self._name = self._DEFAULT_SPY_NAME |
| 46 | + |
| 47 | + if isinstance(module_name, str): |
| 48 | + self._module_name: Optional[str] = module_name |
| 49 | + elif module_name is True and source is not None: |
| 50 | + self._module_name = getattr(source, "__module__", None) |
| 51 | + else: |
| 52 | + self._module_name = None |
| 53 | + |
| 54 | + def get_name(self) -> str: |
| 55 | + """Get the Spec's human readable name. |
| 56 | +
|
| 57 | + Name may be manually specified or derived from the object the Spec |
| 58 | + represents. |
| 59 | + """ |
| 60 | + return self._name |
| 61 | + |
| 62 | + def get_full_name(self) -> str: |
| 63 | + """Get the full name of the spec. |
| 64 | +
|
| 65 | + Full name includes the module name of the object the Spec represents, |
| 66 | + if available. |
| 67 | + """ |
| 68 | + name = self._name |
| 69 | + module_name = self._module_name |
| 70 | + return f"{module_name}.{name}" if module_name else name |
| 71 | + |
| 72 | + def get_signature(self) -> Optional[inspect.Signature]: |
| 73 | + """Get the Spec's signature, if Spec represents a callable.""" |
| 74 | + try: |
| 75 | + return inspect.signature(self._source) # type: ignore[arg-type] |
| 76 | + except TypeError: |
| 77 | + return None |
| 78 | + |
| 79 | + def get_class_type(self) -> Optional[Type[Any]]: |
| 80 | + """Get the Spec's class type, if Spec represents a class.""" |
| 81 | + return self._source if inspect.isclass(self._source) else None |
| 82 | + |
| 83 | + def get_is_async(self) -> bool: |
| 84 | + """Get whether the Spec represents an async. callable.""" |
| 85 | + source = self._source |
| 86 | + |
| 87 | + # `iscoroutinefunction` does not work for `partial` on Python < 3.8 |
| 88 | + if isinstance(source, functools.partial): |
| 89 | + source = source.func |
| 90 | + |
| 91 | + # check if spec source is a class with a __call__ method |
| 92 | + elif inspect.isclass(source): |
| 93 | + call_method = inspect.getattr_static(source, "__call__", None) |
| 94 | + if inspect.isfunction(call_method): |
| 95 | + source = call_method |
| 96 | + |
| 97 | + return inspect.iscoroutinefunction(source) |
| 98 | + |
| 99 | + def bind_args(self, *args: Any, **kwargs: Any) -> BoundArgs: |
| 100 | + """Bind given args and kwargs to the Spec's signature, if possible. |
| 101 | +
|
| 102 | + If no signature or unable to bind, will simply pass args and kwargs |
| 103 | + through without modification. |
| 104 | + """ |
| 105 | + signature = self.get_signature() |
| 106 | + |
| 107 | + if signature: |
| 108 | + try: |
| 109 | + bound_args = signature.bind(*args, **kwargs) |
| 110 | + except TypeError as e: |
| 111 | + # stacklevel: 4 ensures warning is linked to call location |
| 112 | + warnings.warn(IncorrectCallWarning(e), stacklevel=4) |
| 113 | + else: |
| 114 | + args = bound_args.args |
| 115 | + kwargs = bound_args.kwargs |
| 116 | + |
| 117 | + return BoundArgs(args=args, kwargs=kwargs) |
| 118 | + |
| 119 | + def get_child_spec(self, name: str) -> "Spec": |
| 120 | + """Get a child attribute, property, or method's Spec from this Spec.""" |
| 121 | + source = self._source |
| 122 | + child_name = f"{self._name}.{name}" |
| 123 | + child_source = None |
| 124 | + |
| 125 | + if inspect.isclass(source): |
| 126 | + # use type hints to get child spec for class attributes |
| 127 | + child_hint = _get_type_hints(source).get(name) |
| 128 | + # use inspect to get child spec for methods and properties |
| 129 | + child_source = inspect.getattr_static(source, name, child_hint) |
| 130 | + |
| 131 | + if isinstance(child_source, property): |
| 132 | + child_source = _get_type_hints(child_source.fget).get("return") |
| 133 | + |
| 134 | + elif isinstance(child_source, staticmethod): |
| 135 | + child_source = child_source.__func__ |
| 136 | + |
| 137 | + elif inspect.isfunction(child_source): |
| 138 | + # consume the `self` argument of the method to ensure proper |
| 139 | + # signature reporting by wrapping it in a partial |
| 140 | + child_source = functools.partial(child_source, None) |
| 141 | + |
| 142 | + return Spec(source=child_source, name=child_name, module_name=self._module_name) |
| 143 | + |
| 144 | + |
| 145 | +def _get_type_hints(obj: Any) -> Dict[str, Any]: |
| 146 | + """Get type hints for an object, if possible. |
| 147 | +
|
| 148 | + The builtin `typing.get_type_hints` may fail at runtime, |
| 149 | + e.g. if a type is subscriptable according to mypy but not |
| 150 | + according to Python. |
| 151 | + """ |
| 152 | + try: |
| 153 | + return get_type_hints(obj) |
| 154 | + except Exception: |
| 155 | + return {} |
0 commit comments