Skip to content

Commit afacf5b

Browse files
additional unit tests for the new contextualization mechanism
1 parent 623effd commit afacf5b

File tree

1 file changed

+36
-6
lines changed

1 file changed

+36
-6
lines changed

tests/unit/views/test_methods_base.py

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,28 @@
11
# pylint: disable=missing-docstring, missing-return-doc, missing-param-doc, disallowed-name
22

3-
4-
from typing import List, Literal, Tuple
3+
import asyncio
4+
from dataclasses import dataclass
5+
from typing import List, Literal, Tuple, Union, Optional
56

67
from dbally.collection.results import ViewExecutionResult
78
from dbally.iql import IQLQuery
89
from dbally.views.decorators import view_filter
9-
from dbally.views.exposed_functions import MethodParamWithTyping
10+
from dbally.views.exposed_functions import MethodParamWithTyping, ExposedFunction
1011
from dbally.views.methods_base import MethodsBaseView
12+
from dbally.context import BaseCallerContext
13+
from dbally.iql_generator.iql_generator import IQLGenerator
14+
from dbally.audit.event_tracker import EventTracker
15+
from dbally.prompt.elements import FewShotExample
16+
from dbally.llms.clients.base import LLMOptions
17+
from dbally.llms.base import LLM
18+
19+
20+
@dataclass
21+
class TestCallerContext(BaseCallerContext):
22+
"""
23+
Mock class for testing context.
24+
"""
25+
current_year: Literal['2023', '2024']
1126

1227

1328
class MockMethodsBase(MethodsBaseView):
@@ -22,7 +37,7 @@ def method_foo(self, idx: int) -> None:
2237
"""
2338

2439
@view_filter()
25-
def method_bar(self, cities: List[str], year: Literal["2023", "2024"], pairs: List[Tuple[str, int]]) -> str:
40+
def method_bar(self, cities: List[str], year: Union[Literal["2023", "2024"], TestCallerContext], pairs: List[Tuple[str, int]]) -> str:
2641
return f"hello {cities} in {year} of {pairs}"
2742

2843
async def apply_filters(self, filters: IQLQuery) -> None:
@@ -47,9 +62,24 @@ def test_list_filters() -> None:
4762
assert method_bar.description == ""
4863
assert method_bar.parameters == [
4964
MethodParamWithTyping("cities", List[str]),
50-
MethodParamWithTyping("year", Literal["2023", "2024"]),
65+
MethodParamWithTyping("year", Union[Literal["2023", "2024"], TestCallerContext]),
5166
MethodParamWithTyping("pairs", List[Tuple[str, int]]),
5267
]
5368
assert (
54-
str(method_bar) == "method_bar(cities: List[str], year: Literal['2023', '2024'], pairs: List[Tuple[str, int]])"
69+
str(method_bar) == "method_bar(cities: List[str], year: Literal['2023', '2024'] | AskerContext, pairs: List[Tuple[str, int]])"
5570
)
71+
72+
73+
async def test_contextualization() -> None:
74+
mock_view = MockMethodsBase()
75+
filters = mock_view.list_filters()
76+
test_context = TestCallerContext("2024")
77+
mock_view.contextualize_filters(filters, [test_context])
78+
79+
method_foo = [f for f in filters if f.name == "method_foo"][0]
80+
assert method_foo.context_class is None
81+
assert method_foo.context is None
82+
83+
method_bar = [f for f in filters if f.name == "method_bar"][0]
84+
assert method_bar.context_class is TestCallerContext
85+
assert method_bar.context is test_context

0 commit comments

Comments
 (0)