
Commit 16367be
docs: update ab testing cookbook
Parent: 7645443

File tree: 3 files changed, +51 −30 lines

cookbook/ab_testing.py (43 additions, 17 deletions)

```diff
@@ -5,8 +5,8 @@
 
 from openai import OpenAI
 
-from parea import Parea, get_current_trace_id, trace, trace_insert
-from parea.schemas import FeedbackRequest
+from parea import Parea, get_current_trace_id, parea_logger, trace, trace_insert
+from parea.schemas import EvaluationResult, UpdateLog
 
 client = OpenAI()
 # instantiate Parea client
@@ -15,16 +15,27 @@
 p.wrap_openai_client(client)
 
 
-@trace
-def generate_email(user: str) -> Tuple[str, str]:
-    """Randomly chooses a prompt to perform an A/B test for generating email. Returns the email and the trace ID.
-    The latter is used to tie-back the collected feedback from the user."""
+ab_test_name = "long-vs-short-emails"
+
+
+@trace  # decorator to trace functions with Parea
+def generate_email(user: str) -> Tuple[str, str, str]:
+    # randomly choose to generate a long or short email
     if random.random() < 0.5:
-        trace_insert({"metadata": {"ab_test_0": "variant_0"}})
+        variant = "variant_0"
         prompt = f"Generate a long email for {user}"
     else:
-        trace_insert({"metadata": {"ab_test_0": "variant_1"}})
+        variant = "variant_1"
         prompt = f"Generate a short email for {user}"
+    # tag the request with the A/B test name & chosen variant
+    trace_insert(
+        {
+            "metadata": {
+                "ab_test_name": ab_test_name,
+                f"ab_test_{ab_test_name}": variant,
+            }
+        }
+    )
 
     email = (
         client.chat.completions.create(
@@ -39,22 +50,37 @@ def generate_email(user: str) -> Tuple[str, str]:
         .choices[0]
         .message.content
     )
+    # return the trace_id and the chosen variant in addition to the email
+    return email, get_current_trace_id(), variant
 
-    return email, get_current_trace_id()
 
+def capture_feedback(feedback: float, trace_id: str, ab_test_variant: str, user_corrected_email: str = None) -> None:
+    field_name_to_value_map = {
+        "scores": [EvaluationResult(name=f"ab_test_{ab_test_variant}", score=feedback, reason="any additional user feedback on why it's good/bad")],
+    }
+    if user_corrected_email:
+        field_name_to_value_map["target"] = user_corrected_email
 
-def main():
-    # generate email and get trace ID
-    email, trace_id = generate_email("Max Mustermann")
-
-    # log user feedback on email using trace ID
-    p.record_feedback(
-        FeedbackRequest(
+    parea_logger.update_log(
+        UpdateLog(
             trace_id=trace_id,
-            score=1.0,
+            field_name_to_value_map=field_name_to_value_map,
         )
     )
 
 
+def main():
+    # generate email and get trace ID
+    email, trace_id, ab_test_variant = generate_email("Max Mustermann")
+
+    # create biased feedback for shorter emails
+    if ab_test_variant == "variant_1":
+        user_feedback = 0.0 if random.random() < 0.7 else 1.0
+    else:
+        user_feedback = 0.0 if random.random() < 0.3 else 1.0
+
+    capture_feedback(user_feedback, trace_id, ab_test_variant, "Hi Max")
+
+
 if __name__ == "__main__":
     main()
```
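
The updated cookbook tags each traced request with the chosen variant and feeds user feedback back to that trace via `parea_logger.update_log`. As a rough illustration of how the pattern could be exercised end to end, here is a minimal sketch that repeats the simulated run and averages the recorded feedback per variant locally; the `simulate_ab_test` helper and its `n_runs` parameter are made up for this example and are not part of the commit (it assumes it is appended to cookbook/ab_testing.py, so `generate_email` and `capture_feedback` are in scope):

```python
import random
from collections import defaultdict
from typing import Dict


def simulate_ab_test(n_runs: int = 100) -> Dict[str, float]:
    # collect simulated feedback per variant
    feedback_by_variant = defaultdict(list)
    for _ in range(n_runs):
        email, trace_id, variant = generate_email("Max Mustermann")
        # same biased simulation as main(): the short-email variant is disliked more often
        threshold = 0.7 if variant == "variant_1" else 0.3
        feedback = 0.0 if random.random() < threshold else 1.0
        capture_feedback(feedback, trace_id, variant)
        feedback_by_variant[variant].append(feedback)
    # average feedback per variant, e.g. {"variant_0": 0.7, "variant_1": 0.3}
    return {v: sum(scores) / len(scores) for v, scores in feedback_by_variant.items()}
```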

cookbook/evals_and_experiments/run_experiment_agreement_among_evals.py (6 additions, 12 deletions)

```diff
@@ -1,11 +1,12 @@
-import random
 from typing import List
+
 import os
+import random
 
 from dotenv import load_dotenv
 
 from parea import Parea, trace
-from parea.schemas import EvaluatedLog, Log, EvaluationResult
+from parea.schemas import EvaluatedLog, EvaluationResult, Log
 
 load_dotenv()
 
@@ -14,10 +15,8 @@
 
 def random_eval_factory(trial: int):
     def random_eval(log: Log) -> EvaluationResult:
-        return EvaluationResult(
-            score=1 if random.random() < 0.5 else 0,
-            name=f'random_eval_{trial}'
-        )
+        return EvaluationResult(score=1 if random.random() < 0.5 else 0, name=f"random_eval_{trial}")
+
     return random_eval
 
 
@@ -58,9 +57,4 @@ def percent_evals_agree(logs: List[EvaluatedLog]) -> float:
 
 # You can optionally run the experiment manually by calling `.run()`
 if __name__ == "__main__":
-    p.experiment(
-        name="Greeting",
-        data=data,
-        func=starts_with_f,
-        dataset_level_evals=[percent_evals_agree]
-    ).run()
+    p.experiment(name="Greeting", data=data, func=starts_with_f, dataset_level_evals=[percent_evals_agree]).run()
```
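
The body of the dataset-level eval `percent_evals_agree` lies outside these hunks. Purely as an illustration of what a function with that signature might do, here is a hypothetical sketch, assuming each `EvaluatedLog` exposes its per-log results as a `scores` list of `EvaluationResult` objects; this is not the cookbook's actual implementation:

```python
from typing import List

from parea.schemas import EvaluatedLog


def percent_evals_agree_sketch(logs: List[EvaluatedLog]) -> float:
    # Hypothetical: fraction of logs whose random evals all returned the same score.
    agreeing = 0
    for log in logs:
        scores = [s.score for s in (log.scores or [])]
        if scores and len(set(scores)) == 1:
            agreeing += 1
    return agreeing / len(logs) if logs else 0.0
```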

parea/helpers.py (2 additions, 1 deletion)

```diff
@@ -4,7 +4,7 @@
 import random
 import uuid
 from copy import deepcopy
-from datetime import datetime
+from datetime import datetime, timedelta
 
 import pytz
 from attr import asdict, fields_dict
@@ -79,6 +79,7 @@ def serialize_values(metadata: Dict[str, Any]) -> Dict[str, str]:
 
 def timezone_aware_now() -> datetime:
     return datetime.now(pytz.utc)
+    # return datetime.now(pytz.utc) - timedelta(days=6)
 
 
 def structure_trace_log_from_api(d: dict) -> TraceLogTree:
```
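
The added commented-out line looks like a convenience for backdating timestamps by six days during local testing. A hypothetical alternative, shown only as a sketch and not part of this commit, would read the offset from an environment variable (the `PAREA_BACKDATE_DAYS` name is made up) instead of editing code:

```python
import os
from datetime import datetime, timedelta

import pytz


def timezone_aware_now() -> datetime:
    # Offset defaults to 0; set e.g. PAREA_BACKDATE_DAYS=6 to backdate while testing.
    offset_days = int(os.getenv("PAREA_BACKDATE_DAYS", "0"))  # hypothetical variable name
    return datetime.now(pytz.utc) - timedelta(days=offset_days)
```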
