Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 63 additions & 10 deletions openjury/evaluate.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import re
from dataclasses import dataclass
from datetime import datetime
from datetime import datetime, timezone
from pathlib import Path

import numpy as np
Expand All @@ -10,6 +10,7 @@
from langchain_core.language_models.llms import LLM

from openjury.instruction_dataset import load_instructions
from openjury.repro import write_run_metadata, _to_jsonable
from openjury.utils import (
read_df,
data_root,
Expand Down Expand Up @@ -65,6 +66,25 @@ def load_judge_system_and_user_prompt(
return system_prompt, user_prompt_template


def resolve_judge_prompts(
*,
provide_explanation: bool,
system_prompt: str | None = None,
user_prompt_template: str | None = None,
) -> tuple[str, str]:
default_system_prompt, default_user_prompt_template = (
load_judge_system_and_user_prompt(provide_explanation=provide_explanation)
)
return (
system_prompt if system_prompt is not None else default_system_prompt,
(
user_prompt_template
if user_prompt_template is not None
else default_user_prompt_template
),
)


def evaluate_completions(
dataset: str = "alpaca-eval",
judge_chat_model: LLM = None,
Expand All @@ -88,6 +108,7 @@ def evaluate_completions(
exceeding context limit
:return:
"""
run_started_at = datetime.now(timezone.utc)
local_path_tables = data_root / "tables"
download_hf(name=dataset, local_path=local_path_tables)

Expand Down Expand Up @@ -138,6 +159,11 @@ def get_output(df_outputs: pd.DataFrame, dataset: str, method: str):

judge_chat_model = Together(model="meta-llama/Llama-3.3-70B-Instruct-Turbo")

(
judge_system_prompt,
judge_user_prompt_template,
) = resolve_judge_prompts(provide_explanation=provide_explanation)

annotations = annotate_battles(
judge_chat_model=judge_chat_model,
instructions=instructions.tolist(),
Expand Down Expand Up @@ -174,7 +200,37 @@ def get_output(df_outputs: pd.DataFrame, dataset: str, method: str):
output_folder.mkdir(parents=True, exist_ok=True)
pd.DataFrame(annotations).to_csv(output_folder / "annotations.csv", index=False)
with open(output_folder / "results.json", "w") as f:
json.dump(results, f)
json.dump(_to_jsonable(results), f, allow_nan=False)

run_metadata = {
"dataset": dataset,
"method_A": method_A,
"method_B": method_B,
"num_annotations": num_annotations,
"n_annotations": len(instructions),
"use_tqdm": use_tqdm,
"truncate_input_chars": truncate_input_chars,
"provide_explanation": provide_explanation,
}

try:
write_run_metadata(
output_dir=output_folder,
entrypoint="openjury.evaluate.evaluate_completions",
run=run_metadata,
results=results,
input_payloads={
"instruction_index": instructions.index.tolist(),
"instructions": instructions.tolist(),
"completions_A": completions_A.loc[instructions.index].tolist(),
"completions_B": completions_B.loc[instructions.index].tolist(),
},
judge_system_prompt=judge_system_prompt,
judge_user_prompt_template=judge_user_prompt_template,
started_at_utc=run_started_at,
)
except OSError as e:
print(f"Warning: failed to write run metadata: {e}")


@dataclass
Expand Down Expand Up @@ -227,14 +283,11 @@ def annotate_battles(
# alternatively pass list of tuples
assert len(instructions) == len(completions_A) == len(completions_B)

(
default_system_prompt,
default_user_prompt_template,
) = load_judge_system_and_user_prompt(provide_explanation=provide_explanation)
if system_prompt is None:
system_prompt = default_system_prompt
if user_prompt_template is None:
user_prompt_template = default_user_prompt_template
system_prompt, user_prompt_template = resolve_judge_prompts(
provide_explanation=provide_explanation,
system_prompt=system_prompt,
user_prompt_template=user_prompt_template,
)

prompt_template = ChatPromptTemplate.from_messages(
[("system", system_prompt), ("user", user_prompt_template)]
Expand Down
52 changes: 41 additions & 11 deletions openjury/generate_and_evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,22 @@

import argparse
import json
import os
from dataclasses import dataclass, asdict, field
from datetime import datetime
from dataclasses import asdict, dataclass, field
from datetime import datetime, timezone
from functools import partial
from pathlib import Path

import numpy as np
import pandas as pd

from openjury.evaluate import annotate_battles, PairScore
from openjury.evaluate import (
annotate_battles,
PairScore,
resolve_judge_prompts,
)
from openjury.generate import generate_instructions, generate_base
from openjury.instruction_dataset import load_instructions
from openjury.repro import write_run_metadata, _to_jsonable
from openjury.utils import data_root, read_df, download_hf
from openjury.utils import make_model, cache_function_dataframe

Expand Down Expand Up @@ -272,6 +276,7 @@ def main(args: CliArgs):
3) create annotations
"""

run_started_at = datetime.now(timezone.utc)
print(
f"Using dataset {args.dataset} and evaluating models {args.model_A} and {args.model_B}."
)
Expand Down Expand Up @@ -382,6 +387,13 @@ def main(args: CliArgs):
# the default system prompt of annotate is to compare instruction tuned models.

system_prompt = None
(
effective_judge_system_prompt,
judge_user_prompt_template,
) = resolve_judge_prompts(
provide_explanation=args.provide_explanation,
system_prompt=system_prompt,
)
annotations = annotate_battles(
judge_chat_model=judge_chat_model,
instructions=instructions.head(n_instructions).tolist(),
Expand Down Expand Up @@ -416,10 +428,6 @@ def main(args: CliArgs):
res_folder = Path(args.result_folder) / name
res_folder.mkdir(parents=True, exist_ok=True)

# save argument for results analysis
with open(res_folder / f"args-{name}.json", "w") as f:
json.dump(asdict(args), f, indent=2)

print(f"Saving results to {res_folder}")
df = pd.DataFrame(annotations)
df["instruction_index"] = instructions.head(n_instructions).index.tolist()
Expand Down Expand Up @@ -476,14 +484,36 @@ def main(args: CliArgs):
"num_ties": num_ties,
"num_missing": num_battles - (num_losses + num_wins + num_ties),
"preferences": prefs.tolist(),
"date": str(datetime.now().isoformat()),
"user": os.getenv("USER", ""),
}
print(f"{args.model_A} vs {args.model_B} judged by {args.judge_model}")
print_results(results)

with open(res_folder / f"results-{name}.json", "w") as f:
json.dump(results, f, indent=2)
json.dump(_to_jsonable(results), f, indent=2, allow_nan=False)

eval_instruction_index = instructions.head(n_instructions).index.tolist()
eval_instructions = instructions.head(n_instructions).tolist()
eval_completions_A = completions_A.head(n_instructions).tolist()
eval_completions_B = completions_B.head(n_instructions).tolist()

try:
write_run_metadata(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we test this function instead of the whole entrypoint?
Ideally with possible edge cases.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed it with 35f9d4f. Now I have a test_repro function that checks write_run_metadata only.

output_dir=res_folder,
entrypoint="openjury.generate_and_evaluate.main",
run=asdict(args),
results=results,
input_payloads={
"instruction_index": eval_instruction_index,
"instructions": eval_instructions,
"completions_A": eval_completions_A,
"completions_B": eval_completions_B,
},
judge_system_prompt=effective_judge_system_prompt,
judge_user_prompt_template=judge_user_prompt_template,
started_at_utc=run_started_at,
)
except OSError as e:
print(f"Warning: failed to write run metadata: {e}")

return prefs

Expand Down
Loading
Loading