|
2 | 2 |
|
3 | 3 | from __future__ import annotations |
4 | 4 |
|
| 5 | +from collections.abc import Callable |
5 | 6 | from contextlib import suppress |
6 | 7 | from dataclasses import dataclass |
7 | 8 | from functools import cache |
8 | 9 | from importlib import import_module |
9 | 10 | from inspect import isclass, isroutine |
10 | | -from typing import Any, Callable, Union |
| 11 | +from types import UnionType |
| 12 | +from typing import Any, Union, get_type_hints |
11 | 13 |
|
12 | 14 | from sphinx_codeautolink.parse import Name, NameBreak |
13 | 15 |
|
@@ -116,34 +118,27 @@ def call_value(cursor: Cursor) -> None: |
116 | 118 |
|
117 | 119 | def get_return_annotation(func: Callable) -> type | None: |
118 | 120 | """Determine the target of a function return type hint.""" |
119 | | - annotations = getattr(func, "__annotations__", {}) |
120 | | - ret_annotation = annotations.get("return", None) |
| 121 | + annotation = get_type_hints(func).get("return") |
121 | 122 |
|
122 | 123 | # Inner type from typing.Optional or Union[None, T] |
123 | | - origin = getattr(ret_annotation, "__origin__", None) |
124 | | - args = getattr(ret_annotation, "__args__", None) |
125 | | - if origin is Union and len(args) == 2: # noqa: PLR2004 |
| 124 | + origin = getattr(annotation, "__origin__", None) |
| 125 | + args = getattr(annotation, "__args__", None) |
| 126 | + if (origin is Union or isinstance(annotation, UnionType)) and len(args) == 2: # noqa: PLR2004 |
126 | 127 | nonetype = type(None) |
127 | 128 | if args[0] is nonetype: |
128 | | - ret_annotation = args[1] |
| 129 | + annotation = args[1] |
129 | 130 | elif args[1] is nonetype: |
130 | | - ret_annotation = args[0] |
131 | | - |
132 | | - # Try to resolve a string annotation in the module scope |
133 | | - if isinstance(ret_annotation, str): |
134 | | - location = fully_qualified_name(func) |
135 | | - mod, _ = closest_module(tuple(location.split("."))) |
136 | | - ret_annotation = getattr(mod, ret_annotation, ret_annotation) |
| 131 | + annotation = args[0] |
137 | 132 |
|
138 | 133 | if ( |
139 | | - not ret_annotation |
140 | | - or not isinstance(ret_annotation, type) |
141 | | - or hasattr(ret_annotation, "__origin__") |
| 134 | + not annotation |
| 135 | + or not isinstance(annotation, type) |
| 136 | + or hasattr(annotation, "__origin__") |
142 | 137 | ): |
143 | 138 | msg = f"Unable to follow return annotation of {get_name_for_debugging(func)}." |
144 | 139 | raise CouldNotResolve(msg) |
145 | 140 |
|
146 | | - return ret_annotation |
| 141 | + return annotation |
147 | 142 |
|
148 | 143 |
|
149 | 144 | def fully_qualified_name(thing: type | Callable) -> str: |
|
0 commit comments