Skip to content

Commit 09bac55

Browse files
reworked parse_param_type() function to increase performance, generality and properly handle types: Union[Type1, Type2, ...], __main__.SomeCustomClass
1 parent 5fd802f commit 09bac55

File tree

1 file changed

+27
-7
lines changed

1 file changed

+27
-7
lines changed

src/dbally/views/exposed_functions.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,46 @@
1-
import re
21
from dataclasses import dataclass
2+
from inspect import isclass
33
from typing import _GenericAlias # type: ignore
44
from typing import Optional, Sequence, Type, Union
55

6+
import typing_extensions as type_ext
7+
68
from dbally.context.context import BaseCallerContext
79
from dbally.similarity import AbstractSimilarityIndex
810

911

10-
def parse_param_type(param_type: Union[type, _GenericAlias]) -> str:
12+
def parse_param_type(param_type: Union[type, _GenericAlias, str]) -> str:
1113
"""
1214
Parses the type of a method parameter and returns a string representation of it.
1315
1416
Args:
15-
param_type: type of the parameter
17+
param_type: Type of the parameter.
1618
1719
Returns:
18-
str: string representation of the type
20+
A string representation of the type.
1921
"""
20-
if param_type in {int, float, str, bool, list, dict, set, tuple}:
22+
23+
# TODO consider using hasattr() to ensure correctness of the IF's below
24+
if isclass(param_type):
2125
return param_type.__name__
22-
if param_type.__module__ == "typing":
23-
return re.sub(r"\btyping\.", "", str(param_type))
26+
27+
# typing.Literal['aaa', 'bbb'] edge case handler
28+
# the args are strings not types thus isclass('aaa') is False
29+
# at the same type string has no __module__ property which causes an error
30+
if isinstance(param_type, str):
31+
return f"'{param_type}'"
32+
33+
if param_type.__module__ == "typing" or param_type.__module__ == "typing_extensions":
34+
type_args = type_ext.get_args(param_type)
35+
if type_args:
36+
param_name = param_type._name # pylint: disable=protected-access
37+
if param_name is None:
38+
# workaround for typing.Literal, because: `typing.Literal['aaa', 'bbb']._name is None`
39+
# but at the same time: `type_ext.get_origin(typing.Literal['aaa', 'bbb'])._name == "Literal"`
40+
param_name = type_ext.get_origin(param_type)._name # pylint: disable=protected-access
41+
42+
args_str_repr = ", ".join(parse_param_type(arg) for arg in type_args)
43+
return f"{param_name}[{args_str_repr}]"
2444

2545
return str(param_type)
2646

0 commit comments

Comments
 (0)