1
1
import re
2
+ import typing_extensions as type_ext
2
3
from dataclasses import dataclass
3
4
from typing import _GenericAlias # type: ignore
4
5
from typing import Optional , Sequence , Type , Union
6
+ from inspect import isclass
5
7
6
8
from dbally .context .context import BaseCallerContext
7
9
from dbally .similarity import AbstractSimilarityIndex
@@ -12,15 +14,22 @@ def parse_param_type(param_type: Union[type, _GenericAlias]) -> str:
12
14
Parses the type of a method parameter and returns a string representation of it.
13
15
14
16
Args:
15
- param_type: type of the parameter
17
+ param_type: Type of the parameter.
16
18
17
19
Returns:
18
- str: string representation of the type
20
+ A string representation of the type.
19
21
"""
20
- if param_type in {int , float , str , bool , list , dict , set , tuple }:
22
+
23
+ # TODO consider using hasattr() to ensure correctness of the IF's below
24
+ if isclass (param_type ):
21
25
return param_type .__name__
26
+
27
+ if type_ext .get_origin (param_type ) is Union :
28
+ args_str_repr = ', ' .join (parse_param_type (arg ) for arg in type_ext .get_args (param_type ))
29
+ return f"Union[{ args_str_repr } ]"
30
+
22
31
if param_type .__module__ == "typing" :
23
- return re . sub ( r"\btyping\." , "" , str ( param_type ))
32
+ return param_type . _name
24
33
25
34
return str (param_type )
26
35
0 commit comments