Skip to content

Commit d482638

Browse files
authored
feat: few-shot selector (#42)
1 parent cd5bf7b commit d482638

File tree

13 files changed

+435
-62
lines changed

13 files changed

+435
-62
lines changed

.coveragerc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,5 @@ omit =
1212
exclude_lines =
1313
pragma: no cover
1414
if __name__ == .__main__.
15+
\.\.\.
1516
show_missing = True

benchmark/dbally_benchmark/iql_benchmark.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from dbally.iql_generator.iql_generator import IQLGenerator
2424
from dbally.iql_generator.iql_prompt_template import UnsupportedQueryError, default_iql_template
2525
from dbally.llms.litellm import LiteLLM
26+
from dbally.prompts.formatters import IQLInputFormatter
2627
from dbally.views.structured import BaseStructuredView
2728

2829

@@ -31,11 +32,10 @@ async def _run_iql_for_single_example(
3132
) -> IQLResult:
3233
filter_list = view.list_filters()
3334
event_tracker = EventTracker()
35+
input_formatter = IQLInputFormatter(question=example.question, filters=filter_list)
3436

3537
try:
36-
iql_filters, _ = await iql_generator.generate_iql(
37-
question=example.question, filters=filter_list, event_tracker=event_tracker
38-
)
38+
iql_filters, _ = await iql_generator.generate_iql(input_formatter=input_formatter, event_tracker=event_tracker)
3939
except UnsupportedQueryError:
4040
return IQLResult(question=example.question, iql_filters="UNSUPPORTED_QUERY", exception_raised=True)
4141

examples/recruiting/views.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1-
from typing import Literal
1+
from datetime import date
2+
from typing import List, Literal
23

34
import awoc # pip install a-world-of-countries
45
import sqlalchemy
6+
from dateutil.relativedelta import relativedelta
57
from sqlalchemy import and_, select
68

79
from dbally import SqlAlchemyBaseView, decorators
10+
from dbally.prompts.elements import FewShotExample
811

912
from .db import Candidate
1013

@@ -57,3 +60,43 @@ def is_from_continent( # pylint: disable=W0602, C0116, W9011
5760
@decorators.view_filter()
5861
def studied_at(self, university: str) -> sqlalchemy.ColumnElement: # pylint: disable=W0602, C0116, W9011
5962
return Candidate.university == university
63+
64+
65+
class FewShotRecruitmentView(RecruitmentView):
    """
    A view for the recruitment database including examples of question:answers pairs (few-shot).
    """

    @decorators.view_filter()
    def is_available_within_months(  # pylint: disable=W0602, C0116, W9011
        self, months: int
    ) -> sqlalchemy.ColumnElement:
        # Matches candidates whose `available_from` date falls between today
        # and `months` calendar months from now (inclusive BETWEEN bounds).
        start = date.today()
        end = start + relativedelta(months=months)
        return Candidate.available_from.between(start, end)

    def list_few_shots(self) -> List[FewShotExample]:  # pylint: disable=W9011
        # NOTE(review): the lambda bodies below are read back as *source text*
        # by FewShotExample (via inspect.getsource) to produce the example
        # answers shown to the LLM — their exact formatting is significant,
        # so do not reformat or inline them.
        return [
            FewShotExample(
                "Which candidates studied at University of Toronto?",
                'studied_at("University of Toronto")',
            ),
            FewShotExample(
                "Do we have any soon available candidate?",
                lambda: self.is_available_within_months(1),
            ),
            FewShotExample(
                "Do we have any soon available perfect fits for senior data scientist positions?",
                lambda: (
                    self.is_available_within_months(1)
                    and self.data_scientist_position()
                    and self.has_seniority("senior")
                ),
            ),
            FewShotExample(
                "List all junior or senior data scientist positions",
                lambda: (
                    self.data_scientist_position() and (self.has_seniority("junior") or self.has_seniority("senior"))
                ),
            ),
        ]
Lines changed: 10 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
1-
import copy
2-
from typing import Callable, List, Optional, Tuple, TypeVar
1+
from typing import List, Optional, Tuple, TypeVar
32

43
from dbally.audit.event_tracker import EventTracker
5-
from dbally.iql_generator.iql_prompt_template import IQLPromptTemplate, default_iql_template
4+
from dbally.iql_generator.iql_prompt_template import IQLPromptTemplate, default_iql_template # noqa
65
from dbally.llms.base import LLM
76
from dbally.llms.clients.base import LLMOptions
8-
from dbally.views.exposed_functions import ExposedFunction
7+
from dbally.prompts.formatters import IQLInputFormatter
98

109

1110
class IQLGenerator:
@@ -24,26 +23,16 @@ class IQLGenerator:
2423

2524
TException = TypeVar("TException", bound=Exception)
2625

27-
def __init__(
28-
self,
29-
llm: LLM,
30-
prompt_template: Optional[IQLPromptTemplate] = None,
31-
promptify_view: Optional[Callable] = None,
32-
) -> None:
26+
def __init__(self, llm: LLM) -> None:
3327
"""
3428
Args:
3529
llm: LLM used to generate IQL
36-
prompt_template: If not provided by the users is set to `default_iql_template`
37-
promptify_view: Function formatting filters for prompt
3830
"""
3931
self._llm = llm
40-
self._prompt_template = prompt_template or copy.deepcopy(default_iql_template)
41-
self._promptify_view = promptify_view or _promptify_filters
4232

4333
async def generate_iql(
4434
self,
45-
filters: List[ExposedFunction],
46-
question: str,
35+
input_formatter: IQLInputFormatter,
4736
event_tracker: EventTracker,
4837
conversation: Optional[IQLPromptTemplate] = None,
4938
llm_options: Optional[LLMOptions] = None,
@@ -52,30 +41,25 @@ async def generate_iql(
5241
Uses LLM to generate IQL in text form
5342
5443
Args:
55-
question: user question
56-
filters: list of filters exposed by the view
44+
input_formatter: formatter used to prepare prompt arguments dictionary
5745
event_tracker: event store used to audit the generation process
5846
conversation: conversation to be continued
5947
llm_options: options to use for the LLM client
6048
6149
Returns:
6250
IQL - iql generated based on the user question
6351
"""
64-
filters_for_prompt = self._promptify_view(filters)
6552

66-
template = conversation or self._prompt_template
53+
conversation, fmt = input_formatter(conversation or default_iql_template)
6754

6855
llm_response = await self._llm.generate_text(
69-
template=template,
70-
fmt={"filters": filters_for_prompt, "question": question},
56+
template=conversation,
57+
fmt=fmt,
7158
event_tracker=event_tracker,
7259
options=llm_options,
7360
)
7461

75-
iql_filters = self._prompt_template.llm_response_parser(llm_response)
76-
77-
if conversation is None:
78-
conversation = self._prompt_template
62+
iql_filters = conversation.llm_response_parser(llm_response)
7963

8064
conversation = conversation.add_assistant_message(content=llm_response)
8165

@@ -98,19 +82,3 @@ def add_error_msg(self, conversation: IQLPromptTemplate, errors: List[TException
9882
msg += str(error) + "\n"
9983

10084
return conversation.add_user_message(content=msg)
101-
102-
103-
def _promptify_filters(
104-
filters: List[ExposedFunction],
105-
) -> str:
106-
"""
107-
Formats filters for prompt
108-
109-
Args:
110-
filters: list of filters exposed by the view
111-
112-
Returns:
113-
filters_for_prompt: filters formatted for prompt
114-
"""
115-
filters_for_prompt = "\n".join([str(filter) for filter in filters])
116-
return filters_for_prompt

src/dbally/llms/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def format_prompt(self, template: PromptTemplate, fmt: Dict[str, str]) -> ChatFo
5252
Returns:
5353
Prompt in the format of the client.
5454
"""
55-
return [{**message, "content": message["content"].format(**fmt)} for message in template.chat]
55+
return [{"role": message["role"], "content": message["content"].format(**fmt)} for message in template.chat]
5656

5757
def count_tokens(self, messages: ChatFormat, fmt: Dict[str, str]) -> int:
5858
"""

src/dbally/prompts/elements.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
import inspect
2+
import re
3+
import textwrap
4+
from typing import Callable, Union
5+
6+
7+
class FewShotExample:
    """
    A question:answer representation for few-shot prompting.

    The answer may be given directly as a string, or as a zero-argument
    lambda whose source code is extracted and cleaned up to form the
    answer text (giving the author type-safety and code completion).
    """

    def __init__(self, question: str, answer_expr: Union[str, Callable]) -> None:
        """
        Args:
            question: sample question
            answer_expr: it can be either a stringified expression or a lambda for greater safety and code completions.

        Raises:
            ValueError: If answer_expr is not a correct type.
        """
        self.question = question
        self.answer_expr = answer_expr

        if isinstance(self.answer_expr, str):
            self.answer = self.answer_expr
        elif callable(answer_expr):
            self.answer = self._parse_lambda(answer_expr)
        else:
            raise ValueError("Answer expression should be either a string or a lambda")

    def _parse_lambda(self, expr: Callable) -> str:
        """
        Parses provided callable in order to extract the lambda code.
        All comments and references to variables like `self` etc will be removed
        to form a simple lambda representation.

        Args:
            expr: lambda expression to parse

        Returns:
            Parsed lambda in a form of cleaned up string
        """
        # extract lambda from code (getsource returns the line(s) starting
        # at the `lambda` keyword)
        expr_source = textwrap.dedent(inspect.getsource(expr))
        expr_body = expr_source.replace("lambda:", "")

        # clean up by removing comments, new lines, free vars (self etc)
        parsed_expr = re.sub("\\#.*\n", "\n", expr_body, flags=re.MULTILINE)

        # Strip `<free_var>.` prefixes (e.g. `self.`) from attribute accesses.
        # Iterating co_freevars (instead of indexing co_freevars[0]) avoids an
        # IndexError for lambdas that close over no variables at all, and also
        # handles lambdas closing over more than one name.
        for free_var in expr.__code__.co_freevars:
            for m_name in expr.__code__.co_names:
                parsed_expr = parsed_expr.replace(f"{free_var}.{m_name}", m_name)

        # clean up any dangling commas or leading and trailing brackets
        parsed_expr = " ".join(parsed_expr.split()).strip().rstrip(",").replace("( ", "(").replace(" )", ")")
        if parsed_expr.startswith("("):
            parsed_expr = parsed_expr[1:-1]

        return parsed_expr

    def __str__(self) -> str:
        return self.answer

src/dbally/prompts/formatters.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
import copy
2+
from abc import ABCMeta, abstractmethod
3+
from typing import Dict, List, Tuple
4+
5+
from dbally.prompts.elements import FewShotExample
6+
from dbally.prompts.prompt_template import PromptTemplate
7+
from dbally.views.exposed_functions import ExposedFunction
8+
9+
10+
def _promptify_filters(
11+
filters: List[ExposedFunction],
12+
) -> str:
13+
"""
14+
Formats filters for prompt
15+
16+
Args:
17+
filters: list of filters exposed by the view
18+
19+
Returns:
20+
filters formatted for prompt
21+
"""
22+
filters_for_prompt = "\n".join([str(filter) for filter in filters])
23+
return filters_for_prompt
24+
25+
26+
class InputFormatter(metaclass=ABCMeta):
    """
    Base class for callables that turn caller-supplied parameters into the
    arguments expected by an IQL prompt template.
    """

    @abstractmethod
    def __call__(self, conversation_template: PromptTemplate) -> Tuple[PromptTemplate, Dict[str, str]]:
        """
        Formats the inputs for the given prompt template.

        Args:
            conversation_template: a prompt template to use.

        Returns:
            A tuple of the (possibly adjusted) template and a dictionary
            of formatted inputs to interpolate into it.
        """
42+
43+
44+
class IQLInputFormatter(InputFormatter):
    """
    Prepares the arguments accepted by the default IQL prompt.
    """

    def __init__(self, filters: List[ExposedFunction], question: str) -> None:
        self.filters = filters
        self.question = question

    def __call__(self, conversation_template: PromptTemplate) -> Tuple[PromptTemplate, Dict[str, str]]:
        """
        Runs the input formatting for provided prompt template.

        Args:
            conversation_template: a prompt template to use.

        Returns:
            A tuple with the template (unchanged) and a dictionary with
            formatted filters and a question.
        """
        fmt_args = {
            "filters": _promptify_filters(self.filters),
            "question": self.question,
        }
        return conversation_template, fmt_args
67+
68+
69+
class IQLFewShotInputFormatter(InputFormatter):
    """
    Formats provided parameters to a form acceptable by the default IQL prompt.
    Calling it injects `examples` into the conversation, after the system
    message and ahead of the remaining (non-example) messages.
    """

    def __init__(
        self,
        filters: List[ExposedFunction],
        examples: List[FewShotExample],
        question: str,
    ) -> None:
        self.filters = filters
        self.question = question
        self.examples = examples

    def __call__(self, conversation_template: PromptTemplate) -> Tuple[PromptTemplate, Dict[str, str]]:
        """
        Performs a deep copy of the provided template and injects the few-shot
        examples into its chat history, then prepares the filters and question
        to be included within the prompt.

        Args:
            conversation_template: a prompt template to use to inject few-shot examples.

        Returns:
            A tuple with the deep-copied, example-enriched template and a
            dictionary with formatted filters and a question.
        """
        template_copy = copy.deepcopy(conversation_template)

        # First chat entry is the system message; drop any previously
        # injected example turns (marked with "is_example") so repeated
        # calls do not duplicate them.
        system_message, *rest = template_copy.chat
        retained_msgs = [msg for msg in rest if "is_example" not in msg]

        example_turns = []
        for example in self.examples:
            example_turns.append({"role": "user", "content": example.question, "is_example": True})
            example_turns.append({"role": "assistant", "content": example.answer, "is_example": True})

        template_copy.chat = (
            system_message,
            *example_turns,
            *retained_msgs,
        )

        return template_copy, {
            "filters": _promptify_filters(self.filters),
            "question": self.question,
        }

src/dbally/views/base.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from dbally.collection.results import ViewExecutionResult
66
from dbally.llms.base import LLM
77
from dbally.llms.clients.base import LLMOptions
8+
from dbally.prompts.elements import FewShotExample
89
from dbally.similarity import AbstractSimilarityIndex
910

1011
IndexLocation = Tuple[str, str, str]
@@ -49,3 +50,12 @@ def list_similarity_indexes(self) -> Dict[AbstractSimilarityIndex, List[IndexLoc
4950
Mapping of similarity indexes to their locations.
5051
"""
5152
return {}
53+
54+
def list_few_shots(self) -> List[FewShotExample]:
    """
    Lists all examples to be injected into a few-shot prompt.

    Returns:
        List of few-shot examples; empty by default, subclasses override
        to supply view-specific examples.
    """
    no_examples: List[FewShotExample] = []
    return no_examples

0 commit comments

Comments
 (0)