Skip to content

Commit 1294a9c

Browse files
type hint parsing changes: SomeCustomContext -> AskerContext; Union[a, b] -> a | b; removed typing & typing_extensions module name prefixes
1 parent c97ba15 commit 1294a9c

File tree

1 file changed

+21
-10
lines changed

1 file changed

+21
-10
lines changed

src/dbally/views/exposed_functions.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from dataclasses import dataclass
22
from inspect import isclass
33
from typing import _GenericAlias # type: ignore
4-
from typing import Optional, Sequence, Type, Union
4+
from typing import Optional, Sequence, Type, Union, Generator
55

66
import typing_extensions as type_ext
77

@@ -22,6 +22,11 @@ def parse_param_type(param_type: Union[type, _GenericAlias, str]) -> str:
2222

2323
# TODO consider using hasattr() to ensure correctness of the IF's below
2424
if isclass(param_type):
25+
if issubclass(param_type, BaseCallerContext):
26+
# this mechanism ensures the LLM will be able to notice the relation between
27+
# the keyword-call specified in the prompt and the filter method signatures
28+
return BaseCallerContext._alias
29+
2530
return param_type.__name__
2631

2732
# typing.Literal['aaa', 'bbb'] edge case handler
@@ -31,16 +36,22 @@ def parse_param_type(param_type: Union[type, _GenericAlias, str]) -> str:
3136
return f"'{param_type}'"
3237

3338
if param_type.__module__ in ["typing", "typing_extensions"]:
39+
param_name = param_type._name # pylint: disable=protected-access
40+
if param_name is None:
41+
# workaround for typing.Literal, because: `typing.Literal['aaa', 'bbb']._name is None`
42+
# but at the same time: `type_ext.get_origin(typing.Literal['aaa', 'bbb'])._name == "Literal"`
43+
param_name = type_ext.get_origin(param_type)._name # pylint: disable=protected-access
44+
3445
type_args = type_ext.get_args(param_type)
35-
if type_args:
36-
param_name = param_type._name # pylint: disable=protected-access
37-
if param_name is None:
38-
# workaround for typing.Literal, because: `typing.Literal['aaa', 'bbb']._name is None`
39-
# but at the same time: `type_ext.get_origin(typing.Literal['aaa', 'bbb'])._name == "Literal"`
40-
param_name = type_ext.get_origin(param_type)._name # pylint: disable=protected-access
41-
42-
args_str_repr = ", ".join(parse_param_type(arg) for arg in type_args)
43-
return f"{param_name}[{args_str_repr}]"
46+
if not type_args:
47+
return param_name
48+
49+
parsed_args: Generator[str] = (parse_param_type(arg) for arg in type_args)
50+
if type_ext.get_origin(param_type) is Union:
51+
return " | ".join(parsed_args)
52+
53+
parsed_args_concatanated = ", ".join(parsed_args)
54+
return f"{param_name}[{parsed_args_concatanated}]"
4455

4556
return str(param_type)
4657

0 commit comments

Comments
 (0)