Commit ec6649f

evals docs (#100)
1 parent d28dee8 commit ec6649f

2 files changed: +200, -9 lines

docs/testing-evals.md

Lines changed: 180 additions & 4 deletions
@@ -266,14 +266,190 @@ async def test_forecast(override_weather_agent: None):

## Evals

"Evals" refers to evaluating a model's performance for a specific application.

!!! danger "Warning"
    Unlike unit tests, evals are an emerging art/science; anyone who claims to know for sure exactly how your evals should be defined can safely be ignored.

Evals are generally more like benchmarks than unit tests: they never "pass", although they do "fail"; you care mostly about how they change over time.

Since evals need to be run against the real model, they can be slow and expensive to run, so you generally won't want to run them in CI for every commit.

### Measuring performance

The hardest part of evals is measuring how well the model has performed.

In some cases (e.g. an agent to generate SQL) there are simple, easy to run tests that can be used to measure performance (e.g. is the SQL valid? Does it return the right results? Does it return just the right results?).

In other cases (e.g. an agent that gives advice on quitting smoking) it can be very hard or impossible to make quantitative measures of performance — in the smoking case you'd really need to run a double-blind trial over months, then wait 40 years and observe health outcomes to know if changes to your prompt were an improvement.

There are a few different strategies you can use to measure performance:

* **End to end, self-contained tests** — like the SQL example, we can test the final result of the agent near-instantly
* **Synthetic self-contained tests** — writing unit-test-style checks that the output is as expected, checks like `#!python 'chewing gum' in response` (see the sketch after this list); while these checks might seem simplistic they can be helpful, and one nice characteristic is that it's easy to tell what's wrong when they fail
* **LLMs evaluating LLMs** — using another model, or even the same model with a different prompt, to evaluate the performance of the agent (like when the class marks each other's homework because the teacher has a hangover); while the downsides and complexities of this approach are obvious, some think it can be a useful tool in the right circumstances
* **Evals in prod** — measuring the end results of the agent in production, then creating a quantitative measure of performance, so you can easily measure changes over time as you change the prompt or model used; [logfire](logfire.md) can be extremely useful in this case since you can write a custom query to measure the performance of your agent
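
As a rough illustration of the second strategy above, a synthetic self-contained check might look something like the sketch below. It assumes a hypothetical `quit_smoking_agent` (not defined in this document) and that `anyio` is configured for async tests as in the unit-testing examples earlier on this page:

```py
import pytest

from quit_smoking_app import quit_smoking_agent  # hypothetical module, for illustration only

pytestmark = pytest.mark.anyio


async def test_mentions_chewing_gum():
    # this runs against the real model, so it's an eval rather than a unit test
    result = await quit_smoking_agent.run('What can I chew instead of a cigarette?')
    # simplistic, but when it fails it's easy to see which expectation wasn't met
    assert 'chewing gum' in result.data.lower()
```
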
### System prompt customization

The system prompt is the developer's primary tool in controlling an agent's behavior, so it's often useful to be able to customise the system prompt and see how performance changes. This is particularly relevant when the system prompt contains a list of examples and you want to understand how changing that list affects the model's performance.

Let's assume we have the following app for running SQL generated from a user prompt (this example omits a lot of details for brevity; see the [SQL gen](examples/sql-gen.md) example for more complete code):

```py title="sql_app.py"
import json
from pathlib import Path
from typing import Union

from pydantic_ai import Agent, CallContext

from fake_database import DatabaseConn


class SqlSystemPrompt:  # (1)!
    def __init__(
        self, examples: Union[list[dict[str, str]], None] = None, db: str = 'PostgreSQL'
    ):
        if examples is None:
            # if examples aren't provided, load them from file, this is the default
            with Path('examples.json').open('rb') as f:
                self.examples = json.load(f)
        else:
            self.examples = examples

        self.db = db

    def build_prompt(self) -> str:  # (2)!
        return f"""\
Given the following {self.db} table of records, your job is to
write a SQL query that suits the user's request.

Database schema:
CREATE TABLE records (
  ...
);

{''.join(self.format_example(example) for example in self.examples)}
"""

    @staticmethod
    def format_example(example: dict[str, str]) -> str:  # (3)!
        return f"""\
<example>
<request>{example['request']}</request>
<sql>{example['sql']}</sql>
</example>
"""


sql_agent = Agent(
    'gemini-1.5-flash',
    deps_type=SqlSystemPrompt,
)


@sql_agent.system_prompt
async def system_prompt(ctx: CallContext[SqlSystemPrompt]) -> str:
    return ctx.deps.build_prompt()


async def user_search(user_prompt: str) -> list[dict[str, str]]:
    """Search the database based on the user's prompts."""
    ...  # (4)!
    result = await sql_agent.run(user_prompt, deps=SqlSystemPrompt())
    conn = DatabaseConn()
    return await conn.execute(result.data)
```

`examples.json` pairs each request with the SQL it should produce, e.g. the request "show me error records with the tag 'foobar'" maps to `SELECT * FROM records WHERE level = 'error' and 'foobar' = ANY(tags)`. The file looks something like this:

```json title="examples.json"
[
  {
    "request": "Show me all records",
    "sql": "SELECT * FROM records;"
  },
  {
    "request": "Show me all records from 2021",
    "sql": "SELECT * FROM records WHERE date_trunc('year', date) = '2021-01-01';"
  },
  {
    "request": "show me error records with the tag 'foobar'",
    "sql": "SELECT * FROM records WHERE level = 'error' and 'foobar' = ANY(tags);"
  },
  ...
]
```

Now we want a way to quantify the success of the SQL generation, so we can judge how changes to the agent affect its performance.

We can use [`Agent.override`][pydantic_ai.agent.Agent.override] to replace the system prompt with a custom one that uses a subset of the examples, and then run the application code (in this case `user_search`). We also run the actual SQL from the examples and compare the "correct" result from the example SQL to the result of running the SQL generated by the agent. (We compare the results of running the SQL rather than the SQL itself, since the SQL might be semantically equivalent but written in a different way.)

To get a quantitative measure of performance, we assign points to each run as follows (see the sketch after this list for how a single case might be scored):

* **-100** points if the generated SQL is invalid
* **-1** point for each row returned by the agent (so returning lots of results is discouraged)
* **+5** points for each row returned by the agent that matches the expected result
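
As a minimal sketch of this scoring scheme (not part of the example app), a single case might be scored like this, where `agent_rows` and `expected_rows` are hypothetical lists of rows keyed by `id`, and `invalid_sql` stands in for catching `QueryError` as in the full example below:

```py
def score_case(agent_rows: list[dict], expected_rows: list[dict], invalid_sql: bool) -> int:
    """Hypothetical helper: score one example under the scheme above."""
    if invalid_sql:
        return -100
    agent_ids = [r['id'] for r in agent_rows]
    expected_ids = {r['id'] for r in expected_rows}
    # -1 per returned row, +5 per returned row that matches an expected row
    return -len(agent_ids) + 5 * len(set(agent_ids) & expected_ids)
```

So, for example, an agent that returns 10 rows of which 8 match the expected result scores `-10 + 5 * 8 = 30` for that case.
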
We use 5-fold cross-validation to judge the performance of the agent using our existing set of examples.

```py title="test_sql_app.py"
import json
import statistics
from pathlib import Path
from itertools import chain

from fake_database import DatabaseConn, QueryError
from sql_app import sql_agent, SqlSystemPrompt, user_search


async def main():
    with Path('examples.json').open('rb') as f:
        examples = json.load(f)

    # split examples into 5 folds
    fold_size = len(examples) // 5
    folds = [examples[i : i + fold_size] for i in range(0, len(examples), fold_size)]
    conn = DatabaseConn()
    scores = []

    for i, fold in enumerate(folds, start=1):
        fold_score = 0
        # build all the other folds into a single list of examples
        other_folds = list(chain(*(f for j, f in enumerate(folds, start=1) if j != i)))
        # create a new system prompt using the examples from the other folds
        system_prompt = SqlSystemPrompt(examples=other_folds)

        # override the agent's deps so it uses the new system prompt
        with sql_agent.override(deps=system_prompt):
            for case in fold:
                try:
                    agent_results = await user_search(case['request'])
                except QueryError as e:
                    print(f'Fold {i} {case}: {e}')
                    fold_score -= 100
                else:
                    # get the expected results using the SQL from this case
                    expected_results = await conn.execute(case['sql'])

                    agent_ids = [r['id'] for r in agent_results]
                    # each returned row has a score of -1
                    fold_score -= len(agent_ids)
                    expected_ids = {r['id'] for r in expected_results}

                    # each returned row that matches the expected result has a score of +5
                    fold_score += 5 * len(set(agent_ids) & expected_ids)

        scores.append(fold_score)

    overall_score = statistics.mean(scores)
    print(f'Overall score: {overall_score:0.2f}')
    #> Overall score: 12.00
```

We can then change the prompt, the model, or the examples and see how the score changes over time.
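
One lightweight way to track the score over time is to append each run's result to a local log, then compare runs as you tweak things. This is just a sketch, not part of the example app; the `evals_log.jsonl` path and `label` field are made up for illustration:

```py
import json
from datetime import datetime, timezone
from pathlib import Path


def record_score(overall_score: float, label: str) -> None:
    """Append a scored eval run to a local log so runs can be compared over time."""
    entry = {
        'when': datetime.now(timezone.utc).isoformat(),
        'label': label,  # e.g. 'gemini-1.5-flash, prompt v3, 15 examples'
        'score': overall_score,
    }
    with Path('evals_log.jsonl').open('a') as f:
        f.write(json.dumps(entry) + '\n')
```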

tests/test_examples.py

Lines changed: 20 additions & 5 deletions
@@ -1,10 +1,13 @@
 from __future__ import annotations as _annotations

+import json
+import os
 import re
 import sys
 from collections.abc import AsyncIterator, Iterable
 from dataclasses import dataclass, field
 from datetime import date
+from pathlib import Path
 from types import ModuleType
 from typing import Any

@@ -51,8 +54,8 @@ class DatabaseConn:
     users: FakeTable = field(default_factory=FakeTable)
     _forecasts: dict[int, str] = field(default_factory=dict)

-    async def execute(self, query: str) -> None:
-        pass
+    async def execute(self, query: str) -> list[dict[str, Any]]:
+        return [{'id': 123, 'name': 'John Doe'}]

     async def store_forecast(self, user_id: int, forecast: str) -> None:
         self._forecasts[user_id] = forecast

@@ -129,6 +132,7 @@ def test_docs_examples(
     mocker: MockerFixture,
     client_with_handler: ClientWithHandler,
     env: TestEnv,
+    tmp_path: Path,
 ):
     mocker.patch('pydantic_ai.agent.models.infer_model', side_effect=mock_infer_model)
     mocker.patch('pydantic_ai._utils.group_by_temporal', side_effect=mock_group_by_temporal)

@@ -145,6 +149,14 @@ def test_docs_examples(
     env.set('GROQ_API_KEY', 'testing')

     prefix_settings = example.prefix_settings()
+    opt_title = prefix_settings.get('title')
+    cwd = Path.cwd()
+
+    if opt_title == 'test_sql_app.py':
+        os.chdir(tmp_path)
+        examples = [{'request': f'sql prompt {i}', 'sql': f'SELECT {i}'} for i in range(15)]
+        with (tmp_path / 'examples.json').open('w') as f:
+            json.dump(examples, f)

     ruff_ignore: list[str] = ['D']
     # `from bank_database import DatabaseConn` wrongly sorted in imports

@@ -153,7 +165,7 @@ def test_docs_examples(
         ruff_ignore.append('I001')

     line_length = 88
-    if prefix_settings.get('title') in ('streamed_hello_world.py', 'streamed_user_profile.py'):
+    if opt_title in ('streamed_hello_world.py', 'streamed_user_profile.py'):
         line_length = 120

     eval_example.set_config(ruff_ignore=ruff_ignore, target_version='py39', line_length=line_length)

@@ -173,8 +185,8 @@ def test_docs_examples(
         eval_example.lint(example)
         module_dict = eval_example.run_print_check(example, call=call_name)

-    debug(prefix_settings)
-    if title := prefix_settings.get('title'):
+    os.chdir(cwd)
+    if title := opt_title:
         if title.endswith('.py'):
             module_name = title[:-3]
             sys.modules[module_name] = module = ModuleType(module_name)

@@ -275,6 +287,9 @@ async def model_logic(messages: list[Message], info: AgentInfo) -> ModelAnyRespo
         else:
             return ModelStructuredResponse(calls=[response])

+        if re.fullmatch(r'sql prompt \d+', m.content):
+            return ModelTextResponse(content='SELECT 1')
+
     elif m.role == 'tool-return' and m.tool_name == 'roulette_wheel':
         win = m.content == 'winner'
         return ModelStructuredResponse(calls=[ToolCall(tool_name='final_result', args=ArgsDict({'response': win}))])
