Skip to content

Commit 73741d9

Browse files
type hint lifting
1 parent 308e2e1 commit 73741d9

File tree

7 files changed

+30
-39
lines changed

7 files changed

+30
-39
lines changed

src/dbally/collection/collection.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from dbally.similarity.index import AbstractSimilarityIndex
1717
from dbally.view_selection.base import ViewSelector
1818
from dbally.views.base import BaseView, IndexLocation
19-
from dbally.context.context import BaseCallerContext
19+
from dbally.context.context import BaseCallerContext, CustomContextsList
2020

2121

2222
class Collection:
@@ -157,7 +157,7 @@ async def ask(
157157
dry_run: bool = False,
158158
return_natural_response: bool = False,
159159
llm_options: Optional[LLMOptions] = None,
160-
context: Optional[List[BaseCallerContext]] = None
160+
context: Optional[CustomContextsList] = None
161161
) -> ExecutionResult:
162162
"""
163163
Ask question in a text form and retrieve the answer based on the available views.

src/dbally/context/_utils.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,15 @@
1-
import typing
21
import typing_extensions as type_ext
32

3+
from typing import Sequence, Tuple, Optional, Type, Any, Union
44
from inspect import isclass
55

66
from dbally.context.context import BaseCallerContext
77
from dbally.views.exposed_functions import MethodParamWithTyping
88

99

1010
def _extract_params_and_context(
11-
filter_method_: typing.Callable, hidden_args: typing.List[str]
12-
) -> typing.Tuple[
13-
typing.List[MethodParamWithTyping],
14-
typing.Optional[typing.Type[BaseCallerContext]]
15-
]:
11+
filter_method_: type_ext.Callable, hidden_args: Sequence[str]
12+
) -> Tuple[Sequence[MethodParamWithTyping], Optional[Type[BaseCallerContext]]]:
1613
"""
1714
Processes the MethodsBaseView filter method signauture to extract the args and type hints in the desired format.
1815
Context claases are getting excluded the returned MethodParamWithTyping list. Only the first BaseCallerContext
@@ -29,7 +26,7 @@ class is returned.
2926
params = []
3027
context = None
3128
# TODO confirm whether to use typing.get_type_hints(method) or method.__annotations__
32-
for name_, type_ in typing.get_type_hints(filter_method_).items():
29+
for name_, type_ in type_ext.get_type_hints(filter_method_).items():
3330
if name_ in hidden_args:
3431
continue
3532

@@ -39,11 +36,11 @@ class is returned.
3936
# this is the case when user provides a context but no other type hint for a specifc arg
4037
# TODO confirm whether this case should be supported
4138
context = type_
42-
type_ = typing.Any
43-
elif type_ext.get_origin(type_) is typing.Union:
39+
type_ = Any
40+
elif type_ext.get_origin(type_) is Union:
4441
union_subtypes = type_ext.get_args(type_)
4542
if not union_subtypes:
46-
type_ = typing.Any
43+
type_ = Any
4744

4845
for subtype_ in union_subtypes: # type: ignore
4946
# TODO add custom error for the situation when user provides more than two contexts for a single filter
@@ -58,7 +55,7 @@ class is returned.
5855

5956

6057
def _does_arg_allow_context(arg: MethodParamWithTyping) -> bool:
61-
if type_ext.get_origin(arg.type) is not typing.Union and not issubclass(arg.type, BaseCallerContext):
58+
if type_ext.get_origin(arg.type) is not Union and not issubclass(arg.type, BaseCallerContext):
6259
return False
6360

6461
for subtype in type_ext.get_args(arg.type):

src/dbally/context/context.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
import ast
22

3-
from typing import List, Optional, Type, TypeVar
3+
from typing import Optional, Type, Sequence
44
from typing_extensions import Self
55
from pydantic import BaseModel
66

77
from dbally.context.exceptions import ContextNotAvailableError
88

99

10-
T = TypeVar('T', bound='BaseCallerContext')
11-
AllCallerContexts = Optional[List[T]] # TODO confirm the naming
10+
CustomContextsList = Sequence[Type['BaseCallerContext']] # TODO confirm the naming
1211

1312

1413
class BaseCallerContext(BaseModel):
@@ -18,7 +17,7 @@ class BaseCallerContext(BaseModel):
1817
"""
1918

2019
@classmethod
21-
def select_context(cls, contexts: List[T]) -> T:
20+
def select_context(cls, contexts: Sequence[Type[Self]]) -> Type[Self]:
2221
if not contexts:
2322
raise ContextNotAvailableError("The LLM detected that the context is required to execute the query and the filter signature allows contextualization while the context was not provided.")
2423

src/dbally/iql/_processor.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import ast
22

3-
from typing import TYPE_CHECKING, Any, List, Optional, Union, Mapping, Dict
4-
from typing_extensions import Callable
3+
from typing import TYPE_CHECKING, Any, List, Optional, Union, Mapping, Type
54

65
from dbally.audit.event_tracker import EventTracker
76
from dbally.iql import syntax
@@ -13,7 +12,7 @@
1312
IQLUnsupportedSyntaxError,
1413
)
1514
from dbally.iql._type_validators import validate_arg_type
16-
from dbally.context.context import BaseCallerContext
15+
from dbally.context.context import BaseCallerContext, CustomContextsList
1716
from dbally.context.exceptions import ContextNotAvailableError, ContextualisationNotAllowed
1817
from dbally.context._utils import _extract_params_and_context, _does_arg_allow_context
1918
from dbally.views.exposed_functions import MethodParamWithTyping, ExposedFunction
@@ -25,15 +24,15 @@ class IQLProcessor:
2524
"""
2625
source: str
2726
allowed_functions: Mapping[str, "ExposedFunction"]
28-
contexts: List[BaseCallerContext]
27+
contexts: CustomContextsList
2928
_event_tracker: EventTracker
3029

3130

3231
def __init__(
3332
self,
3433
source: str,
3534
allowed_functions: List["ExposedFunction"],
36-
contexts: Optional[List[BaseCallerContext]] = None,
35+
contexts: Optional[CustomContextsList] = None,
3736
event_tracker: Optional[EventTracker] = None
3837
) -> None:
3938
self.source = source
@@ -52,7 +51,7 @@ async def process(self) -> syntax.Node:
5251
Raises:
5352
IQLError: if parsing fails.
5453
"""
55-
# TODO adjust this method to prevent making context class constructor calls lowercase
54+
5655
self.source = self._to_lower_except_in_quotes(self.source, ["AND", "OR", "NOT"])
5756

5857
ast_tree = ast.parse(self.source)
@@ -132,15 +131,11 @@ def _parse_arg(
132131
if parent_func_def.context_class is None:
133132
raise ContextualisationNotAllowed("The LLM detected that the context is required to execute the query while the filter signature does not allow it at all.")
134133

135-
if _does_arg_allow_context(arg_spec):
134+
if not _does_arg_allow_context(arg_spec):
135+
print(arg_spec)
136136
raise ContextualisationNotAllowed(f"The LLM detected that the context is required to execute the query while the filter signature does allow it for `{arg_spec.name}` argument.")
137137

138-
context = parent_func_def.context_class.select_context(self.contexts)
139-
140-
try:
141-
return getattr(context, arg_spec.name)
142-
except AttributeError:
143-
raise ContextNotAvailableError(f"The LLM detected that the context is required to execute the query and the context object was provided but it is missing the `{arg_spec.name}` field.")
138+
return parent_func_def.context_class.select_context(self.contexts)
144139

145140
if not isinstance(arg, ast.Constant):
146141
raise IQLArgumentParsingError(arg, self.source)

src/dbally/iql/_query.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
from typing import TYPE_CHECKING, List, Optional
1+
from typing import TYPE_CHECKING, List, Optional, Type
22

33
from ..audit.event_tracker import EventTracker
44
from . import syntax
55
from ._processor import IQLProcessor
6-
from dbally.context.context import BaseCallerContext
6+
from dbally.context.context import BaseCallerContext, CustomContextsList
77

88
if TYPE_CHECKING:
99
from dbally.views.structured import ExposedFunction
@@ -25,7 +25,7 @@ async def parse(
2525
source: str,
2626
allowed_functions: List["ExposedFunction"],
2727
event_tracker: Optional[EventTracker] = None,
28-
context: Optional[List[BaseCallerContext]] = None
28+
context: Optional[CustomContextsList] = None
2929
) -> "IQLQuery":
3030
"""
3131
Parse IQL string to IQLQuery object.

src/dbally/views/base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
import abc
2-
from typing import Dict, List, Optional, Tuple
2+
from typing import Dict, List, Optional, Tuple, Type
33

44
from dbally.audit.event_tracker import EventTracker
55
from dbally.collection.results import ViewExecutionResult
66
from dbally.llms.base import LLM
77
from dbally.llms.clients.base import LLMOptions
88
from dbally.similarity import AbstractSimilarityIndex
9-
from dbally.context.context import BaseCallerContext
9+
from dbally.context.context import BaseCallerContext, CustomContextsList
1010

1111
IndexLocation = Tuple[str, str, str]
1212

@@ -26,7 +26,7 @@ async def ask(
2626
n_retries: int = 3,
2727
dry_run: bool = False,
2828
llm_options: Optional[LLMOptions] = None,
29-
context: Optional[List[BaseCallerContext]] = None
29+
context: Optional[CustomContextsList] = None
3030
) -> ViewExecutionResult:
3131
"""
3232
Executes the query and returns the result.

src/dbally/views/structured.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import abc
22
from collections import defaultdict
3-
from typing import Dict, List, Optional
3+
from typing import Dict, List, Optional, Type
44

55
from dbally.audit.event_tracker import EventTracker
66
from dbally.collection.results import ViewExecutionResult
@@ -10,7 +10,7 @@
1010
from dbally.llms.base import LLM
1111
from dbally.llms.clients.base import LLMOptions
1212
from dbally.views.exposed_functions import ExposedFunction
13-
from dbally.context.context import BaseCallerContext
13+
from dbally.context.context import BaseCallerContext, CustomContextsList
1414

1515
from ..similarity import AbstractSimilarityIndex
1616
from .base import BaseView, IndexLocation
@@ -42,7 +42,7 @@ async def ask(
4242
n_retries: int = 3,
4343
dry_run: bool = False,
4444
llm_options: Optional[LLMOptions] = None,
45-
context: Optional[List[BaseCallerContext]] = None
45+
context: Optional[CustomContextsList] = None
4646
) -> ViewExecutionResult:
4747
"""
4848
Executes the query and returns the result. It generates the IQL query from the natural language query\

0 commit comments

Comments
 (0)