Skip to content

Commit db3b53c

Browse files
Adding initial changes for aggregation example in quickstart code.
1 parent 9fd817f commit db3b53c

File tree

8 files changed

+102
-9
lines changed

8 files changed

+102
-9
lines changed

docs/quickstart/quickstart_code.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
# pylint: disable=missing-return-doc, missing-param-doc, missing-function-docstring
2+
from typing import Union, Tuple, Any
3+
24
import dbally
35
import asyncio
46

@@ -54,14 +56,20 @@ def from_country(self, country: str) -> sqlalchemy.ColumnElement:
5456
"""
5557
return Candidate.country == country
5658

59+
@decorators.view_aggregation()
60+
def group_by_university(self, aggregation:str): # -> Union[Select[Tuple[Any, Any]], Select]: # pylint: disable=W0602, C0116, W9011
61+
return sqlalchemy.select(Candidate.university, sqlalchemy.func.count(Candidate.university).label("count")) \
62+
.group_by(Candidate.university)
63+
5764

5865
async def main():
5966
llm = LiteLLM(model_name="gpt-3.5-turbo")
6067

6168
collection = dbally.create_collection("recruitment", llm, event_handlers=[CLIEventHandler()])
6269
collection.add(CandidateView, lambda: CandidateView(engine))
6370

64-
result = await collection.ask("Find me French candidates suitable for a senior data scientist position.")
71+
# result = await collection.ask("Find me French candidates suitable for a senior data scientist position.")
72+
result = await collection.ask("Could you count the candidates university-wise and present the rows?")
6573

6674
print(f"The generated SQL query is: {result.context.get('sql')}")
6775
print()

src/dbally/iql_generator/iql_generator.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,12 @@ def __init__(
3838
"""
3939
self._llm = llm
4040
self._prompt_template = prompt_template or copy.deepcopy(default_iql_template)
41-
self._promptify_view = promptify_view or _promptify_filters
41+
self._promptify_view = promptify_view or _promptify_filters or _promptify_aggregations
4242

4343
async def generate_iql(
4444
self,
4545
filters: List[ExposedFunction],
46+
aggregations: List[ExposedFunction],
4647
question: str,
4748
event_tracker: EventTracker,
4849
conversation: Optional[IQLPromptTemplate] = None,
@@ -62,12 +63,14 @@ async def generate_iql(
6263
IQL - iql generated based on the user question
6364
"""
6465
filters_for_prompt = self._promptify_view(filters)
66+
aggregations_for_prompt = self._promptify_view(aggregations)
6567

6668
template = conversation or self._prompt_template
6769

6870
llm_response = await self._llm.generate_text(
6971
template=template,
70-
fmt={"filters": filters_for_prompt, "question": question},
72+
fmt={"filters": filters_for_prompt, "question": question,
73+
"aggregation": aggregations_for_prompt},
7174
event_tracker=event_tracker,
7275
options=llm_options,
7376
)
@@ -114,3 +117,19 @@ def _promptify_filters(
114117
"""
115118
filters_for_prompt = "\n".join([str(filter) for filter in filters])
116119
return filters_for_prompt
120+
121+
122+
def _promptify_aggregations(
123+
aggregations: List[ExposedFunction],
124+
) -> str:
125+
"""
126+
Formats filters for prompt
127+
128+
Args:
129+
filters: list of filters exposed by the view
130+
131+
Returns:
132+
filters_for_prompt: filters formatted for prompt
133+
"""
134+
aggregations_for_prompt = "\n".join([str(aggregation) for aggregation in aggregations])
135+
return aggregations_for_prompt

src/dbally/nl_responder/nl_responder_prompt_template.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def __init__(
2424
"""
2525

2626
super().__init__(chat, response_format, llm_response_parser)
27-
self.chat = check_prompt_variables(chat, {"rows", "question"})
27+
self.chat = check_prompt_variables(chat, {"rows", "question", "aggregation"})
2828

2929

3030
default_nl_responder_template = NLResponderPromptTemplate(
@@ -34,7 +34,7 @@ def __init__(
3434
"content": "You are a helpful assistant that helps answer the user's questions "
3535
"based on the table provided. You MUST use the table to answer the question. "
3636
"You are very intelligent and obedient.\n"
37-
"The table ALWAYS contains full answer to a question.\n"
37+
"The table ALWAYS contains full answer to a question including necessary {aggregation}.\n"
3838
"Answer the question in a way that is easy to understand and informative.\n"
3939
"DON'T MENTION using a table in your answer.",
4040
},

src/dbally/nl_responder/query_explainer_prompt_template.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def __init__(
2121
llm_response_parser: Callable = lambda x: x,
2222
) -> None:
2323
super().__init__(chat, response_format, llm_response_parser)
24-
self.chat = check_prompt_variables(chat, {"question", "query", "number_of_results"})
24+
self.chat = check_prompt_variables(chat, {"question", "query", "aggregation", "number_of_results"})
2525

2626

2727
default_query_explainer_template = QueryExplainerPromptTemplate(
@@ -34,14 +34,14 @@ def __init__(
3434
"Your task is to provide natural language description of the table used by the logical query "
3535
"to the database.\n"
3636
"Describe the table in a way that is short and informative.\n"
37-
"Make your answer as short as possible, start it by infroming the user that the underlying "
37+
"Make your answer as short as possible, start it by informing the user that the underlying "
3838
"data is too long to print and then describe the table based on the question and the query.\n"
3939
"DON'T MENTION using a query in your answer.\n",
4040
},
4141
{
4242
"role": "user",
4343
"content": "The query below represents the answer to a question: {question}.\n"
44-
"Describe the table generated using this query: {query}.\n"
44+
"Describe the table generated using this query: {query} which applies {aggregation}.\n"
4545
"Number of results to this query: {number_of_results}.\n",
4646
},
4747
)

src/dbally/view_selection/llm_view_selector.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def __init__(
3535
"""
3636
self._llm = llm
3737
self._prompt_template = prompt_template or copy.deepcopy(default_view_selector_template)
38-
self._promptify_views = promptify_views or _promptify_views
38+
self._promptify_views = promptify_views or _promptify_views or _promptify_aggregations
3939

4040
async def select_view(
4141
self,
@@ -81,3 +81,17 @@ def _promptify_views(views: Dict[str, str]) -> str:
8181
"""
8282

8383
return "\n".join([f"{name}: {description}" for name, description in views.items()])
84+
85+
86+
def _promptify_aggregations(views: Dict[str, str]) -> str:
87+
"""
88+
Formats views for aggregation
89+
90+
Args:
91+
views: dictionary of available view names with corresponding descriptions.
92+
93+
Returns:
94+
views_for_prompt: views formatted for prompt
95+
"""
96+
97+
return "\n".join([f"{name}: {description}" for name, description in views.items()])

src/dbally/views/decorators.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,17 @@ def wrapped(func: typing.Callable) -> typing.Callable: # pylint: disable=missin
1414
return func
1515

1616
return wrapped
17+
18+
def view_aggregation() -> typing.Callable:
19+
"""
20+
Decorator for marking a method as an aggregation
21+
22+
Returns:
23+
Function that returns the decorated method
24+
"""
25+
26+
def wrapped(func: typing.Callable) -> typing.Callable: # pylint: disable=missing-return-doc
27+
func._methodDecorator = view_aggregation # type:ignore # pylint: disable=protected-access
28+
return func
29+
30+
return wrapped

src/dbally/views/sqlalchemy_base.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,26 @@ async def _build_filter_bool_op(self, bool_op: syntax.BoolOp) -> sqlalchemy.Colu
6464
return alchemy_op(await self._build_filter_node(bool_op.child))
6565
raise ValueError(f"BoolOp {bool_op} has no children")
6666

67+
async def _build_aggregation_node(self, node: syntax.Node) -> sqlalchemy.ColumnElement:
68+
"""
69+
Converts a filter node from the IQLQuery to a SQLAlchemy expression.
70+
"""
71+
if isinstance(node, syntax.BoolOp):
72+
return await self._build_filter_bool_op(node)
73+
if isinstance(node, syntax.FunctionCall):
74+
return await self.call_filter_method(node)
75+
76+
raise ValueError(f"Unsupported grammar: {node}")
77+
78+
async def apply_aggregation(self, aggregation: IQLQuery) -> None:
79+
"""
80+
Applies the chosen aggregation to the view.
81+
82+
Args:
83+
aggregation: IQLQuery object representing the aggregation to apply
84+
"""
85+
self._select = self._select.where(await self._build_filter_node(aggregation.root))
86+
6787
def execute(self, dry_run: bool = False) -> ViewExecutionResult:
6888
"""
6989
Executes the generated SQL query and returns the results.

src/dbally/views/structured.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ async def ask(
5858
"""
5959
iql_generator = self.get_iql_generator(llm)
6060
filter_list = self.list_filters()
61+
aggregation_list = self.list_aggregations()
6162

6263
iql_filters, conversation = await iql_generator.generate_iql(
6364
question=query,
@@ -104,6 +105,23 @@ async def apply_filters(self, filters: IQLQuery) -> None:
104105
filters: [IQLQuery](../../concepts/iql.md) object representing the filters to apply
105106
"""
106107

108+
@abc.abstractmethod
109+
def list_aggregations(self) -> List[ExposedFunction]:
110+
"""
111+
112+
Returns:
113+
Aggregations defined inside the View.
114+
"""
115+
116+
@abc.abstractmethod
117+
async def apply_aggregations(self, filters: IQLQuery) -> None:
118+
"""
119+
Applies the chosen filters to the view.
120+
121+
Args:
122+
filters: [IQLQuery](../../concepts/iql.md) object representing the filters to apply
123+
"""
124+
107125
@abc.abstractmethod
108126
def execute(self, dry_run: bool = False) -> ViewExecutionResult:
109127
"""

0 commit comments

Comments
 (0)