Skip to content

Commit 623effd

Browse files
EXPERIMENTAL: reworked context injection such it is handled immediately in 'structured_view.ask()' and than stored in 'ExposedFunction' instances
1 parent a154577 commit 623effd

File tree

14 files changed

+97
-68
lines changed

14 files changed

+97
-68
lines changed

src/dbally/collection/collection.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from dbally.audit.events import RequestEnd, RequestStart
1111
from dbally.collection.exceptions import IndexUpdateError, NoViewFoundError
1212
from dbally.collection.results import ExecutionResult
13-
from dbally.context.context import CustomContext
13+
from dbally.context.context import BaseCallerContext
1414
from dbally.llms.base import LLM
1515
from dbally.llms.clients.base import LLMOptions
1616
from dbally.nl_responder.nl_responder import NLResponder
@@ -157,7 +157,7 @@ async def ask(
157157
dry_run: bool = False,
158158
return_natural_response: bool = False,
159159
llm_options: Optional[LLMOptions] = None,
160-
contexts: Optional[Iterable[CustomContext]] = None,
160+
contexts: Optional[Iterable[BaseCallerContext]] = None,
161161
) -> ExecutionResult:
162162
"""
163163
Ask question in a text form and retrieve the answer based on the available views.

src/dbally/context/context.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,9 @@
22
from abc import ABC
33
from typing import ClassVar, Iterable
44

5-
from typing_extensions import Self, TypeAlias
5+
from typing_extensions import Self
66

7-
from dbally.context.exceptions import ContextNotAvailableError
8-
9-
CustomContext: TypeAlias = "BaseCallerContext"
7+
from dbally.context.exceptions import BaseContextError
108

119

1210
class BaseCallerContext(ABC):
@@ -23,7 +21,7 @@ class BaseCallerContext(ABC):
2321
alias: ClassVar[str] = "AskerContext"
2422

2523
@classmethod
26-
def select_context(cls, contexts: Iterable[CustomContext]) -> Self:
24+
def select_context(cls, contexts: Iterable["BaseCallerContext"]) -> Self:
2725
"""
2826
Typically called from a subclass of BaseCallerContext, selects a member object from `contexts` being
2927
an instance of the same class. Effectively provides a type dispatch mechanism, substituting the context
@@ -36,17 +34,17 @@ class by its right instance.
3634
An instance of the same BaseCallerContext subclass this method is caller from.
3735
3836
Raises:
39-
ContextNotAvailableError: If the sequence of context objects passed as argument is empty.
37+
BaseContextError: If no element in `contexts` matches `cls` class.
4038
"""
4139

42-
if not contexts:
43-
raise ContextNotAvailableError(
44-
"The LLM detected that the context is required to execute the query"
45-
"and the filter signature allows contextualization while the context was not provided."
46-
)
40+
try:
41+
selected_context = next(filter(lambda obj: isinstance(obj, cls), contexts))
42+
except StopIteration as e:
43+
# this custom exception provides more clear message what have just gone wrong
44+
raise BaseContextError() from e
4745

4846
# TODO confirm whether it is possible to design a correct type hints here and skipping `type: ignore`
49-
return next(filter(lambda obj: isinstance(obj, cls), contexts)) # type: ignore
47+
return selected_context # type: ignore
5048

5149
@classmethod
5250
def is_context_call(cls, node: ast.expr) -> bool:

src/dbally/context/exceptions.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,23 @@
1-
class ContextNotAvailableError(Exception):
1+
class BaseContextError(Exception):
22
"""
3-
An exception inheriting from BaseContextException pointining that no sufficient context information
4-
was provided by the user while calling view.ask().
3+
A base error for context handling logic.
54
"""
5+
6+
7+
class SuitableContextNotProvidedError(BaseContextError):
8+
"""
9+
Raised when method argument type hint points that a contextualization is available
10+
but not suitable context was provided.
11+
"""
12+
13+
def __init__(self, filter_fun_signature: str, context_class_name: str) -> None:
14+
# this syntax 'or BaseCallerContext' is just to prevent type checkers
15+
# from raising a warning, as filter_.context_class can be None. It's essenially a fallback that should never
16+
# be reached, unless somebody will use this Exception against its purpose.
17+
# TODO consider raising a warning/error when this happens.
18+
19+
message = (
20+
f"No context of class {context_class_name} was provided"
21+
f"while the filter {filter_fun_signature} requires it."
22+
)
23+
super().__init__(message)

src/dbally/iql/_processor.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from dbally.audit.event_tracker import EventTracker
55
from dbally.context._utils import _does_arg_allow_context
6-
from dbally.context.context import BaseCallerContext, CustomContext
6+
from dbally.context.context import BaseCallerContext
77
from dbally.iql import syntax
88
from dbally.iql._exceptions import (
99
IQLArgumentParsingError,
@@ -23,21 +23,17 @@ class IQLProcessor:
2323
2424
Attributes:
2525
source: Raw LLM response containing IQL filter calls.
26-
allowed_functions: A mapping (typically a dict) of all filters implemented for a certain View.
27-
contexts: A sequence (typically a list) of context objects, each being an instance of
28-
a subclass of BaseCallerContext. May contain contexts irrelevant for the currently processed query.
26+
allowed_functions: A mapping (typically a dict) of all filters implemented for a certain View.=
2927
"""
3028

3129
source: str
3230
allowed_functions: Mapping[str, "ExposedFunction"]
33-
contexts: Iterable[CustomContext]
3431
_event_tracker: EventTracker
3532

3633
def __init__(
3734
self,
3835
source: str,
3936
allowed_functions: Iterable[ExposedFunction],
40-
contexts: Optional[Iterable[CustomContext]] = None,
4137
event_tracker: Optional[EventTracker] = None,
4238
) -> None:
4339
"""
@@ -46,14 +42,11 @@ def __init__(
4642
Args:
4743
source: Raw LLM response containing IQL filter calls.
4844
allowed_functions: An interable (typically a list) of all filters implemented for a certain View.
49-
contexts: An iterable (typically a list) of context objects, each being an instance of
50-
a subclass of BaseCallerContext.
5145
even_tracker: An EvenTracker instance.
5246
"""
5347

5448
self.source = source
5549
self.allowed_functions = {func.name: func for func in allowed_functions}
56-
self.contexts = contexts or []
5750
self._event_tracker = event_tracker or EventTracker()
5851

5952
async def process(self) -> syntax.Node:
@@ -148,7 +141,7 @@ def _parse_arg(
148141
if not _does_arg_allow_context(arg_spec):
149142
raise IQLContextNotAllowedError(arg, self.source, arg_name=arg_spec.name)
150143

151-
return parent_func_def.context_class.select_context(self.contexts)
144+
return parent_func_def.context
152145

153146
if not isinstance(arg, ast.Constant):
154147
raise IQLArgumentParsingError(arg, self.source)

src/dbally/iql/_query.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
1-
from typing import TYPE_CHECKING, Iterable, List, Optional
1+
from typing import TYPE_CHECKING, List, Optional
22

33
from typing_extensions import Self
44

5-
from dbally.context.context import CustomContext
6-
75
from ..audit.event_tracker import EventTracker
86
from . import syntax
97
from ._processor import IQLProcessor
@@ -28,11 +26,7 @@ def __str__(self) -> str:
2826

2927
@classmethod
3028
async def parse(
31-
cls,
32-
source: str,
33-
allowed_functions: List["ExposedFunction"],
34-
event_tracker: Optional[EventTracker] = None,
35-
contexts: Optional[Iterable[CustomContext]] = None,
29+
cls, source: str, allowed_functions: List["ExposedFunction"], event_tracker: Optional[EventTracker] = None
3630
) -> Self:
3731
"""
3832
Parse IQL string to IQLQuery object.
@@ -41,11 +35,10 @@ async def parse(
4135
source: IQL string that needs to be parsed
4236
allowed_functions: list of IQL functions that are allowed for this query
4337
event_tracker: EventTracker object to track events
44-
contexts: An iterable (typically a list) of context objects, each being
45-
an instance of a subclass of BaseCallerContext.
38+
4639
Returns:
4740
IQLQuery object
4841
"""
4942

50-
root = await IQLProcessor(source, allowed_functions, contexts, event_tracker).process()
43+
root = await IQLProcessor(source, allowed_functions, event_tracker).process()
5144
return cls(root=root, source=source)

src/dbally/iql/_type_validators.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def validate_arg_type(required_type: Union[Type, _GenericAlias], value: Any) ->
7070
actual_type = type_ext.get_origin(required_type) if isinstance(required_type, _GenericAlias) else required_type
7171
# typing.Union is an instance of _GenericAlias
7272
if actual_type is None:
73-
# workaround to prevent type warning in line `if isisntanc(value, actual_type):`, TODO check whether necessary
73+
# workaround to prevent type warning in line `if isisntance(value, actual_type):`, TODO check whether necessary
7474
actual_type = required_type.__origin__
7575

7676
if actual_type is Union:

src/dbally/iql_generator/iql_generator.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
1-
from typing import Iterable, List, Optional
1+
from typing import List, Optional
22

33
from dbally.audit.event_tracker import EventTracker
4-
from dbally.context.context import CustomContext
54
from dbally.iql import IQLError, IQLQuery
65
from dbally.iql_generator.prompt import IQL_GENERATION_TEMPLATE, IQLGenerationPromptFormat
76
from dbally.llms.base import LLM
@@ -43,7 +42,6 @@ async def generate_iql(
4342
examples: Optional[List[FewShotExample]] = None,
4443
llm_options: Optional[LLMOptions] = None,
4544
n_retries: int = 3,
46-
contexts: Optional[Iterable[CustomContext]] = None,
4745
) -> IQLQuery:
4846
"""
4947
Generates IQL in text form using LLM.
@@ -55,8 +53,6 @@ async def generate_iql(
5553
examples: List of examples to be injected into the conversation.
5654
llm_options: Options to use for the LLM client.
5755
n_retries: Number of retries to regenerate IQL in case of errors.
58-
contexts: An iterable (typically a list) of context objects, each being
59-
an instance of a subclass of BaseCallerContext.
6056
6157
Returns:
6258
Generated IQL query.
@@ -78,9 +74,7 @@ async def generate_iql(
7874
# TODO: Move response parsing to llm generate_text method
7975
iql = formatted_prompt.response_parser(response)
8076
# TODO: Move IQL query parsing to prompt response parser
81-
return await IQLQuery.parse(
82-
source=iql, allowed_functions=filters, event_tracker=event_tracker, contexts=contexts
83-
)
77+
return await IQLQuery.parse(source=iql, allowed_functions=filters, event_tracker=event_tracker)
8478
except IQLError as exc:
8579
# TODO handle the possibility of variable `response` being not initialized
8680
# while runnning the following line

src/dbally/views/base.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from dbally.audit.event_tracker import EventTracker
77
from dbally.collection.results import ViewExecutionResult
8-
from dbally.context.context import CustomContext
8+
from dbally.context.context import BaseCallerContext
99
from dbally.llms.base import LLM
1010
from dbally.llms.clients.base import LLMOptions
1111
from dbally.prompt.elements import FewShotExample
@@ -29,7 +29,7 @@ async def ask(
2929
n_retries: int = 3,
3030
dry_run: bool = False,
3131
llm_options: Optional[LLMOptions] = None,
32-
contexts: Optional[Iterable[CustomContext]] = None,
32+
contexts: Optional[Iterable[BaseCallerContext]] = None,
3333
) -> ViewExecutionResult:
3434
"""
3535
Executes the query and returns the result.
@@ -59,9 +59,9 @@ def list_similarity_indexes(self) -> Dict[AbstractSimilarityIndex, List[IndexLoc
5959

6060
def list_few_shots(self) -> List[FewShotExample]:
6161
"""
62-
List all examples to be injected into few-shot prompt.
62+
Lists all examples to be injected into few-shot prompt.
6363
6464
Returns:
65-
List of few-shot examples
65+
List of few-shot examples.
6666
"""
6767
return []

src/dbally/views/exposed_functions.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
from dataclasses import dataclass
22
from inspect import isclass
33
from typing import _GenericAlias # type: ignore
4-
from typing import Generator, Optional, Sequence, Type, Union
4+
from typing import Generator, Iterable, Optional, Sequence, Type, Union
55

66
import typing_extensions as type_ext
77

88
from dbally.context.context import BaseCallerContext
9+
from dbally.context.exceptions import BaseContextError, SuitableContextNotProvidedError
910
from dbally.similarity import AbstractSimilarityIndex
1011

1112

@@ -127,6 +128,7 @@ class ExposedFunction:
127128
description: str
128129
parameters: Sequence[MethodParamWithTyping]
129130
context_class: Optional[Type[BaseCallerContext]] = None
131+
context: Optional[BaseCallerContext] = None
130132

131133
def __str__(self) -> str:
132134
base_str = f"{self.name}({', '.join(str(param) for param in self.parameters)})"
@@ -135,3 +137,22 @@ def __str__(self) -> str:
135137
return f"{base_str} - {self.description}"
136138

137139
return base_str
140+
141+
def inject_context(self, contexts: Iterable[BaseCallerContext]) -> None:
142+
"""
143+
Inserts reference to the member of `contexts` of the proper class in self.context.
144+
145+
Args:
146+
contexts: An iterable of user-provided context objects.
147+
148+
Raises:
149+
SuitableContextNotProvidedError: Ff no element in `contexts` matches `self.context_class`.
150+
"""
151+
152+
if self.context_class is None:
153+
return
154+
155+
try:
156+
self.context = self.context_class.select_context(contexts)
157+
except BaseContextError as e:
158+
raise SuitableContextNotProvidedError(str(self), self.context_class.__name__) from e

src/dbally/views/freeform/text2sql/view.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from dbally.audit.event_tracker import EventTracker
1010
from dbally.collection.results import ViewExecutionResult
11-
from dbally.context.context import CustomContext
11+
from dbally.context.context import BaseCallerContext
1212
from dbally.llms.base import LLM
1313
from dbally.llms.clients.base import LLMOptions
1414
from dbally.prompt.template import PromptTemplate
@@ -104,7 +104,7 @@ async def ask(
104104
n_retries: int = 3,
105105
dry_run: bool = False,
106106
llm_options: Optional[LLMOptions] = None,
107-
contexts: Optional[Iterable[CustomContext]] = None,
107+
contexts: Optional[Iterable[BaseCallerContext]] = None,
108108
) -> ViewExecutionResult:
109109
"""
110110
Executes the query and returns the result. It generates the SQL query from the natural language query and

0 commit comments

Comments
 (0)