|
1 |
| -import re |
2 | 1 | from dataclasses import dataclass
|
| 2 | +from inspect import isclass |
3 | 3 | from typing import _GenericAlias # type: ignore
|
4 | 4 | from typing import Optional, Sequence, Type, Union
|
5 | 5 |
|
| 6 | +import typing_extensions as type_ext |
| 7 | + |
6 | 8 | from dbally.context.context import BaseCallerContext
|
7 | 9 | from dbally.similarity import AbstractSimilarityIndex
|
8 | 10 |
|
9 | 11 |
|
10 |
| -def parse_param_type(param_type: Union[type, _GenericAlias]) -> str: |
| 12 | +def parse_param_type(param_type: Union[type, _GenericAlias, str]) -> str: |
11 | 13 | """
|
12 | 14 | Parses the type of a method parameter and returns a string representation of it.
|
13 | 15 |
|
14 | 16 | Args:
|
15 |
| - param_type: type of the parameter |
| 17 | + param_type: Type of the parameter. |
16 | 18 |
|
17 | 19 | Returns:
|
18 |
| - str: string representation of the type |
| 20 | + A string representation of the type. |
19 | 21 | """
|
20 |
| - if param_type in {int, float, str, bool, list, dict, set, tuple}: |
| 22 | + |
| 23 | + # TODO consider using hasattr() to ensure correctness of the IF's below |
| 24 | + if isclass(param_type): |
21 | 25 | return param_type.__name__
|
22 |
| - if param_type.__module__ == "typing": |
23 |
| - return re.sub(r"\btyping\.", "", str(param_type)) |
| 26 | + |
| 27 | + # typing.Literal['aaa', 'bbb'] edge case handler |
| 28 | + # the args are strings not types thus isclass('aaa') is False |
| 29 | + # at the same type string has no __module__ property which causes an error |
| 30 | + if isinstance(param_type, str): |
| 31 | + return f"'{param_type}'" |
| 32 | + |
| 33 | + if param_type.__module__ == "typing" or param_type.__module__ == "typing_extensions": |
| 34 | + type_args = type_ext.get_args(param_type) |
| 35 | + if type_args: |
| 36 | + param_name = param_type._name # pylint: disable=protected-access |
| 37 | + if param_name is None: |
| 38 | + # workaround for typing.Literal, because: `typing.Literal['aaa', 'bbb']._name is None` |
| 39 | + # but at the same time: `type_ext.get_origin(typing.Literal['aaa', 'bbb'])._name == "Literal"` |
| 40 | + param_name = type_ext.get_origin(param_type)._name # pylint: disable=protected-access |
| 41 | + |
| 42 | + args_str_repr = ", ".join(parse_param_type(arg) for arg in type_args) |
| 43 | + return f"{param_name}[{args_str_repr}]" |
24 | 44 |
|
25 | 45 | return str(param_type)
|
26 | 46 |
|
|
0 commit comments