1
1
from dataclasses import dataclass
2
2
from inspect import isclass
3
3
from typing import _GenericAlias # type: ignore
4
- from typing import Optional , Sequence , Type , Union
4
+ from typing import Optional , Sequence , Type , Union , Generator
5
5
6
6
import typing_extensions as type_ext
7
7
@@ -22,6 +22,11 @@ def parse_param_type(param_type: Union[type, _GenericAlias, str]) -> str:
22
22
23
23
# TODO consider using hasattr() to ensure correctness of the IF's below
24
24
if isclass (param_type ):
25
+ if issubclass (param_type , BaseCallerContext ):
26
+ # this mechanism ensures the LLM will be able to notice the relation between
27
+ # the keyword-call specified in the prompt and the filter method signatures
28
+ return BaseCallerContext ._alias
29
+
25
30
return param_type .__name__
26
31
27
32
# typing.Literal['aaa', 'bbb'] edge case handler
@@ -31,16 +36,22 @@ def parse_param_type(param_type: Union[type, _GenericAlias, str]) -> str:
31
36
return f"'{ param_type } '"
32
37
33
38
if param_type .__module__ in ["typing" , "typing_extensions" ]:
39
+ param_name = param_type ._name # pylint: disable=protected-access
40
+ if param_name is None :
41
+ # workaround for typing.Literal, because: `typing.Literal['aaa', 'bbb']._name is None`
42
+ # but at the same time: `type_ext.get_origin(typing.Literal['aaa', 'bbb'])._name == "Literal"`
43
+ param_name = type_ext .get_origin (param_type )._name # pylint: disable=protected-access
44
+
34
45
type_args = type_ext .get_args (param_type )
35
- if type_args :
36
- param_name = param_type . _name # pylint: disable=protected-access
37
- if param_name is None :
38
- # workaround for typing.Literal, because: `typing.Literal['aaa', 'bbb']._name is None`
39
- # but at the same time: ` type_ext.get_origin(typing.Literal['aaa', 'bbb'])._name == "Literal"`
40
- param_name = type_ext . get_origin ( param_type ). _name # pylint: disable=protected-access
41
-
42
- args_str_repr = ", " .join (parse_param_type ( arg ) for arg in type_args )
43
- return f"{ param_name } [{ args_str_repr } ]"
46
+ if not type_args :
47
+ return param_name
48
+
49
+ parsed_args : Generator [ str ] = ( parse_param_type ( arg ) for arg in type_args )
50
+ if type_ext .get_origin (param_type ) is Union :
51
+ return " | " . join ( parsed_args )
52
+
53
+ parsed_args_concatanated = ", " .join (parsed_args )
54
+ return f"{ param_name } [{ parsed_args_concatanated } ]"
44
55
45
56
return str (param_type )
46
57
0 commit comments