|
13 | 13 | # limitations under the License.
|
14 | 14 |
|
15 | 15 | import re
|
| 16 | +import sys |
16 | 17 | from types import FunctionType
|
17 | 18 | from typing import ( # type: ignore
|
18 | 19 | Any,
|
| 20 | + Dict, |
19 | 21 | List,
|
20 | 22 | Match,
|
| 23 | + Optional, |
21 | 24 | Union,
|
22 | 25 | cast,
|
23 | 26 | get_args,
|
24 | 27 | get_origin,
|
25 |
| - get_type_hints, |
26 | 28 | )
|
| 29 | +from typing import get_type_hints as typing_get_type_hints |
27 | 30 |
|
28 | 31 | from playwright._impl._accessibility import Accessibility
|
29 | 32 | from playwright._impl._assertions import (
|
@@ -287,3 +290,34 @@ def return_value(value: Any) -> List[str]:
|
287 | 290 |
|
288 | 291 | api_globals = globals()
|
289 | 292 | assert Serializable
|
| 293 | + |
| 294 | +# Python 3.11+ does not treat default args with None as Optional anymore, this wrapper will still wrap them. |
| 295 | +# https://github.com/python/cpython/issues/90353 |
| 296 | +def get_type_hints(func: Any, globalns: Any) -> Dict[str, Any]: |
| 297 | + original_value = typing_get_type_hints(func, globalns) |
| 298 | + if sys.version_info < (3, 11): |
| 299 | + return original_value |
| 300 | + for key, value in _get_defaults(func).items(): |
| 301 | + if value is None and original_value[key] is not Optional: |
| 302 | + original_value[key] = Optional[original_value[key]] |
| 303 | + return original_value |
| 304 | + |
| 305 | + |
| 306 | +def _get_defaults(func: Any) -> Dict[str, Any]: |
| 307 | + """Internal helper to extract the default arguments, by name.""" |
| 308 | + try: |
| 309 | + code = func.__code__ |
| 310 | + except AttributeError: |
| 311 | + # Some built-in functions don't have __code__, __defaults__, etc. |
| 312 | + return {} |
| 313 | + pos_count = code.co_argcount |
| 314 | + arg_names = code.co_varnames |
| 315 | + arg_names = arg_names[:pos_count] |
| 316 | + defaults = func.__defaults__ or () |
| 317 | + kwdefaults = func.__kwdefaults__ |
| 318 | + res = dict(kwdefaults) if kwdefaults else {} |
| 319 | + pos_offset = pos_count - len(defaults) |
| 320 | + for name, value in zip(arg_names[pos_offset:], defaults): |
| 321 | + assert name not in res |
| 322 | + res[name] = value |
| 323 | + return res |
0 commit comments