Skip to content

Commit dd8b339

Browse files
context benchmark script and data
1 parent afacf5b commit dd8b339

File tree

2 files changed

+297
-0
lines changed

2 files changed

+297
-0
lines changed
Lines changed: 235 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,235 @@
1+
# pylint: disable=missing-return-doc, missing-param-doc, missing-function-docstring
2+
import dbally
3+
import asyncio
4+
import typing
5+
import json
6+
import traceback
7+
import os
8+
9+
import tqdm.asyncio
10+
import sqlalchemy
11+
import pydantic
12+
from typing_extensions import TypeAlias
13+
from copy import deepcopy
14+
from sqlalchemy import create_engine
15+
from sqlalchemy.ext.automap import automap_base, AutomapBase
16+
from dataclasses import dataclass, field
17+
18+
from dbally import decorators, SqlAlchemyBaseView
19+
from dbally.audit.event_handlers.cli_event_handler import CLIEventHandler
20+
from dbally.llms.litellm import LiteLLM
21+
from dbally.context import BaseCallerContext
22+
23+
24+
# Relative location of the recruiting example's SQLite database.
# NOTE: os.path.abspath() resolves this against the current working directory,
# so the script must be started from a directory where the path is valid.
SQLITE_DB_FILE_REL_PATH = "../../examples/recruiting/data/candidates.db"
engine = create_engine("sqlite:///" + os.path.abspath(SQLITE_DB_FILE_REL_PATH))

# Generate the mapped classes from the schema of the database behind `engine`.
Base: AutomapBase = automap_base()
Base.prepare(autoload_with=engine)

# Mapped class for the `candidates` table; every view filter below queries it.
Candidate = Base.classes.candidates
31+
32+
33+
class MyData(BaseCallerContext, pydantic.BaseModel):
    # Caller context describing the person who asks the question.
    # CandidateView filters that accept a MyData instance read a single field
    # from it (first_name, years_of_experience, university or country).

    # Personal data.
    first_name: str
    surname: str

    # Professional profile.
    position: str
    years_of_experience: int
    university: str
    skills: typing.List[str]
    country: str
41+
42+
43+
class OpenPosition(BaseCallerContext, pydantic.BaseModel):
    # Caller context describing a job opening.
    # CandidateView filters that accept an OpenPosition instance read
    # `position` or `min_years_of_experience` from it.

    position: str
    min_years_of_experience: int
    graduated_from_university: str
    required_skills: typing.List[str]
48+
49+
50+
class CandidateView(SqlAlchemyBaseView):
    """
    A view for retrieving candidates from the database.
    """

    # NOTE(review): db-ally presumably exposes the filter names, type hints and
    # docstrings of this class to the LLM, so they are kept verbatim and
    # undocumented filters are described with `#` comments only - confirm
    # before editing any of them, as it could change benchmark results.

    def get_select(self) -> sqlalchemy.Select:
        """
        Creates the initial SqlAlchemy select object, which will be used to build the query.
        """
        base_query = sqlalchemy.select(Candidate)
        return base_query

    @decorators.view_filter()
    def at_least_experience(self, years: typing.Union[int, OpenPosition]) -> sqlalchemy.ColumnElement:
        """
        Filters candidates with at least `years` of experience.
        """
        # An OpenPosition context stands in for its minimum-experience requirement.
        minimum = years.min_years_of_experience if isinstance(years, OpenPosition) else years
        return Candidate.years_of_experience >= minimum

    @decorators.view_filter()
    def at_most_experience(self, years: typing.Union[int, MyData]) -> sqlalchemy.ColumnElement:
        # A MyData context stands in for the caller's own years of experience.
        maximum = years.years_of_experience if isinstance(years, MyData) else years
        return Candidate.years_of_experience <= maximum

    @decorators.view_filter()
    def has_position(self, position: typing.Union[str, OpenPosition]) -> sqlalchemy.ColumnElement:
        # An OpenPosition context stands in for the name of the open position.
        wanted_position = position.position if isinstance(position, OpenPosition) else position
        return Candidate.position == wanted_position

    @decorators.view_filter()
    def senior_data_scientist_position(self) -> sqlalchemy.ColumnElement:
        """
        Filters candidates that can be considered for a senior data scientist position.
        """
        eligible_positions = ["Data Scientist", "Machine Learning Engineer", "Data Engineer"]
        holds_eligible_position = Candidate.position.in_(eligible_positions)
        is_experienced_enough = Candidate.years_of_experience >= 3
        return sqlalchemy.and_(holds_eligible_position, is_experienced_enough)

    @decorators.view_filter()
    def from_country(self, country: typing.Union[str, MyData]) -> sqlalchemy.ColumnElement:
        """
        Filters candidates from a specific country.
        """
        # A MyData context stands in for the caller's own country.
        wanted_country = country.country if isinstance(country, MyData) else country
        return Candidate.country == wanted_country

    @decorators.view_filter()
    def graduated_from_university(self, university: typing.Union[str, MyData]) -> sqlalchemy.ColumnElement:
        # A MyData context stands in for the caller's own university.
        wanted_university = university.university if isinstance(university, MyData) else university
        return Candidate.university == wanted_university

    @decorators.view_filter()
    def has_skill(self, skill: str) -> sqlalchemy.ColumnElement:
        # Substring match of `skill` against the `skills` column.
        like_pattern = f"%{skill}%"
        return Candidate.skills.like(like_pattern)

    @decorators.view_filter()
    def knows_data_analysis(self) -> sqlalchemy.ColumnElement:
        # Unlike the other skill filters, this one matches the `tags` column.
        tag_pattern = "%Data Analysis%"
        return Candidate.tags.like(tag_pattern)

    @decorators.view_filter()
    def knows_python(self) -> sqlalchemy.ColumnElement:
        skill_pattern = "%Python%"
        return Candidate.skills.like(skill_pattern)

    @decorators.view_filter()
    def first_name_is(self, first_name: typing.Union[str, MyData]) -> sqlalchemy.ColumnElement:
        # `Candidate.name` is matched by prefix; a MyData context stands in for
        # the caller's own first name.
        name_prefix = first_name.first_name if isinstance(first_name, MyData) else first_name
        return Candidate.name.startswith(name_prefix)
130+
131+
132+
OpenAILLMName: TypeAlias = typing.Literal['gpt-3.5-turbo', 'gpt-4-turbo', 'gpt-4o']
133+
134+
135+
def setup_collection(model_name: OpenAILLMName) -> dbally.Collection:
    """
    Creates the "recruitment" collection driven by the `model_name` LLM, with CandidateView registered.
    """

    def _build_candidate_view() -> CandidateView:
        # Builder handed to the collection; binds the view to the module-level engine.
        return CandidateView(engine)

    recruitment = dbally.create_collection("recruitment", LiteLLM(model_name=model_name))
    recruitment.add(CandidateView, _build_candidate_view)
    return recruitment
142+
143+
144+
async def generate_iql_from_question(
    collection: dbally.Collection,
    model_name: OpenAILLMName,
    question: str,
    contexts: typing.Optional[typing.List[BaseCallerContext]]
) -> typing.Tuple[str, OpenAILLMName, typing.Optional[str]]:
    """
    Asks `collection` for the IQL matching `question`, passing `dry_run=True`.

    Args:
        collection: Collection that receives the question.
        model_name: Name of the model behind `collection`; it is only echoed
            back so the caller can attribute the answer.
        question: Natural-language question to translate.
        contexts: Caller contexts forwarded to `collection.ask`.

    Returns:
        A `(question, model_name, answer)` tuple, where `answer` is one of:
        the generated IQL with double quotes replaced by single quotes,
        `None` when the result metadata holds no "iql" entry, or
        "FAILED: <exception>" when `collection.ask` raised.
    """
    try:
        result = await collection.ask(question, contexts=contexts, dry_run=True)
    except Exception as e:
        # Deliberately broad: a single failing request is recorded as an answer
        # instead of aborting the whole benchmark run.
        # format_exception_only() may return several newline-terminated lines
        # (e.g. for SyntaxError), and the first one is not guaranteed to be the
        # message, so keep all of them and drop the surrounding whitespace.
        exc_pretty = "".join(traceback.format_exception_only(type(e), e)).strip()
        return question, model_name, f"FAILED: {exc_pretty}"

    iql = result.metadata.get("iql")
    if iql is None:
        return question, model_name, None

    # Use single quotes so the answer has the same quoting as the dataset's
    # `correct_answer` strings.
    return question, model_name, iql.replace('"', "'")
166+
167+
168+
@dataclass
class BenchmarkConfig:
    # Settings consumed by `main`.

    # JSON file with the benchmark questions (read).
    dataset_path: str
    # JSON file the collected answers are saved to (written).
    out_path: str
    # Number of times each question is asked per model.
    n_repeats: int = 5
    # Models to benchmark; a factory is used so instances do not share one list.
    llms: typing.List[OpenAILLMName] = field(
        default_factory=lambda: ['gpt-3.5-turbo', 'gpt-4-turbo', 'gpt-4o']
    )
174+
175+
176+
async def main(config: BenchmarkConfig) -> None:
    """
    Runs the benchmark described by `config` and saves the collected answers.

    Every question from `config.dataset_path` is asked `config.n_repeats` times
    for each model in `config.llms`. The answers are grouped per question and
    per model under an "answers" key added to the question's dataset entry,
    and the resulting entries are written as JSON to `config.out_path`.

    Args:
        config: Benchmark settings (dataset/output paths, repeats, models).
    """
    with open(config.dataset_path, "r", encoding="utf-8") as file:
        test_set = json.load(file)

    # Both kinds of caller context are offered with every question.
    contexts = [
        MyData(
            first_name="John",
            surname="Smith",
            years_of_experience=4,
            position="Data Engineer",
            university="University of Toronto",
            skills=["Python"],
            country="United Kingdom",
        ),
        OpenPosition(
            position="Machine Learning Engineer",
            graduated_from_university="Stanford University",
            min_years_of_experience=1,
            required_skills=["Python", "SQL"],
        ),
    ]

    # Schedule every (model, question, repeat) request up front so that they
    # run concurrently; one collection is shared by all requests of a model.
    tasks: typing.List[asyncio.Task] = []
    for model_name in config.llms:
        collection = setup_collection(model_name)
        for test_case in test_set:
            for _ in range(config.n_repeats):
                tasks.append(
                    asyncio.create_task(
                        generate_iql_from_question(
                            collection, model_name, test_case["question"], contexts=contexts
                        )
                    )
                )

    # Index the dataset entries by question so finished tasks can be matched
    # back to them; the entries themselves are extended in place.
    output_data = {test_case["question"]: test_case for test_case in test_set}

    for task in tqdm.asyncio.tqdm.as_completed(tasks, total=len(tasks)):
        question, llm_name, answer = await task
        entry = output_data[question]
        if "answers" not in entry:
            # A fresh mapping per question, so answer lists are never shared.
            entry["answers"] = {str(name): [] for name in config.llms}

        entry["answers"][llm_name].append(answer)

    output_data_list = list(output_data.values())

    with open(config.out_path, "w", encoding="utf-8") as file:
        json.dump(output_data_list, file, indent=2)
227+
228+
229+
if __name__ == "__main__":
    # Both paths are relative, so they depend on the directory the script is
    # started from.
    asyncio.run(
        main(
            BenchmarkConfig(
                dataset_path="dataset/context_dataset.json",
                out_path="../../context_benchmark_output.json",
            )
        )
    )
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
[
2+
{
3+
"question": "Find me French candidates suitable for my position with at least 1 year of experience.",
4+
"correct_answer": "from_country('France') AND has_position(AskerContext()) AND at_least_experience(1)",
5+
"context": true
6+
},
7+
{
8+
"question": "Please find me candidates from my country who have at most 4 years of experience.",
9+
"correct_answer": "from_country(AskerContext()) AND at_most_experience(4)",
10+
"context": true
11+
},
12+
{
13+
"question": "Find me candidates who graduated from Stanford University and work as Software Engineers.",
14+
"correct_answer": "graduated_from_university('Stanford University') AND has_position('Software Engineer')",
15+
"context": false
16+
},
17+
{
18+
"question": "Find me candidates who graduated from my university",
19+
"correct_answer": "graduated_from_university(AskerContext())",
20+
"context": true
21+
},
22+
{
23+
"question": "Could you find me candidates with at most as much experience as I have who also know Python?",
24+
"correct_answer": "at_most_experience(AskerContext()) AND knows_python()",
25+
"context": true
26+
},
27+
{
28+
"question": "Please find me candidates who know Data Analysis and Python",
29+
"correct_answer": "knows_python() AND knows_data_analysis()",
30+
"context": false
31+
},
32+
{
33+
"question": "Find me candidates with at least minimal required experience for the currently open position.",
34+
"correct_answer": "at_least_experience(AskerContext())",
35+
"context": true
36+
},
37+
{
38+
"question": "List candidates with between 2 and 6 years of experience.",
39+
"correct_answer": "at_least_experience(2) AND at_most_experience(6)",
40+
"context": false
41+
},
42+
{
43+
"question": "Find me candidates who currently have the same position as we look for in our company?",
44+
"correct_answer": "has_position(AskerContext())",
45+
"context": true
46+
},
47+
{
48+
"question": "Please find me senior data scientist candidates who know Data Analysis and come from my country",
49+
"correct_answer": "senior_data_scientist_position() AND has_skill('Data Analysis') AND from_country(AskerContext())",
50+
"context": true
51+
},
52+
{
53+
"question": "Find me candidates that have the same first name as me",
54+
"correct_answer": "first_name_is(AskerContext())",
55+
"context": true
56+
},
57+
{
58+
"question": "List candidates named Mohammed from India",
59+
"correct_answer": "first_name_is('Mohammed') AND from_country('India')",
60+
"context": false
61+
}
62+
]

0 commit comments

Comments
 (0)