-
Notifications
You must be signed in to change notification settings - Fork 4
Versioned Metadata for Reproducibility #23
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 9 commits
45ba74e
08a0436
0be1267
a036a70
a69c1ca
bf56604
a8ee895
6e705a4
ba6022b
35f9d4f
32a18c1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,18 +5,22 @@ | |
|
|
||
| import argparse | ||
| import json | ||
| import os | ||
| from dataclasses import dataclass, asdict, field | ||
| from datetime import datetime | ||
| from dataclasses import dataclass, field, asdict | ||
| from datetime import datetime, timezone | ||
| from functools import partial | ||
| from pathlib import Path | ||
|
|
||
| import numpy as np | ||
| import pandas as pd | ||
|
|
||
| from openjury.evaluate import annotate_battles, PairScore | ||
| from openjury.evaluate import ( | ||
| annotate_battles, | ||
| PairScore, | ||
| resolve_judge_prompts, | ||
| ) | ||
| from openjury.generate import generate_instructions, generate_base | ||
| from openjury.instruction_dataset import load_instructions | ||
| from openjury.repro import write_run_metadata, _to_jsonable | ||
| from openjury.utils import data_root, read_df, download_hf | ||
| from openjury.utils import make_model, cache_function_dataframe | ||
|
|
||
|
|
@@ -272,6 +276,7 @@ def main(args: CliArgs): | |
| 3) create annotations | ||
| """ | ||
|
|
||
| run_started_at = datetime.now(timezone.utc) | ||
| print( | ||
| f"Using dataset {args.dataset} and evaluating models {args.model_A} and {args.model_B}." | ||
| ) | ||
|
|
@@ -382,6 +387,13 @@ def main(args: CliArgs): | |
| # the default system prompt of annotate is to compare instruction tuned models. | ||
|
|
||
| system_prompt = None | ||
| ( | ||
| effective_judge_system_prompt, | ||
| judge_user_prompt_template, | ||
| ) = resolve_judge_prompts( | ||
| provide_explanation=args.provide_explanation, | ||
| system_prompt=system_prompt, | ||
| ) | ||
| annotations = annotate_battles( | ||
| judge_chat_model=judge_chat_model, | ||
| instructions=instructions.head(n_instructions).tolist(), | ||
|
|
@@ -416,10 +428,6 @@ def main(args: CliArgs): | |
| res_folder = Path(args.result_folder) / name | ||
| res_folder.mkdir(parents=True, exist_ok=True) | ||
|
|
||
| # save argument for results analysis | ||
| with open(res_folder / f"args-{name}.json", "w") as f: | ||
| json.dump(asdict(args), f, indent=2) | ||
|
|
||
| print(f"Saving results to {res_folder}") | ||
| df = pd.DataFrame(annotations) | ||
| df["instruction_index"] = instructions.head(n_instructions).index.tolist() | ||
|
|
@@ -476,14 +484,57 @@ def main(args: CliArgs): | |
| "num_ties": num_ties, | ||
| "num_missing": num_battles - (num_losses + num_wins + num_ties), | ||
| "preferences": prefs.tolist(), | ||
| "date": str(datetime.now().isoformat()), | ||
| "user": os.getenv("USER", ""), | ||
| } | ||
| print(f"{args.model_A} vs {args.model_B} judged by {args.judge_model}") | ||
| print_results(results) | ||
|
|
||
| with open(res_folder / f"results-{name}.json", "w") as f: | ||
| json.dump(results, f, indent=2) | ||
| json.dump(_to_jsonable(results), f, indent=2, allow_nan=False) | ||
|
|
||
| try: | ||
| eval_instruction_index = instructions.head(n_instructions).index.tolist() | ||
| eval_instructions = instructions.head(n_instructions).tolist() | ||
| eval_completions_A = completions_A.head(n_instructions).tolist() | ||
| eval_completions_B = completions_B.head(n_instructions).tolist() | ||
|
|
||
| write_run_metadata( | ||
| output_dir=res_folder, | ||
| entrypoint="openjury.generate_and_evaluate.main", | ||
| run={ | ||
|
||
| "dataset": args.dataset, | ||
| "model_A": args.model_A, | ||
| "model_B": args.model_B, | ||
| "judge_model": args.judge_model, | ||
| "provide_explanation": args.provide_explanation, | ||
| "swap_mode": args.swap_mode, | ||
| "n_instructions": n_instructions, | ||
| "ignore_cache": args.ignore_cache, | ||
| "use_tqdm": args.use_tqdm, | ||
| "truncate_all_input_chars": args.truncate_all_input_chars, | ||
| "max_out_tokens_models": args.max_out_tokens_models, | ||
| "max_out_tokens_judge": args.max_out_tokens_judge, | ||
| "max_model_len": args.max_model_len, | ||
| "chat_template": args.chat_template, | ||
| }, | ||
| results=results, | ||
| input_payloads={ | ||
| "instruction_index": eval_instruction_index, | ||
| "instructions": eval_instructions, | ||
| "completions_A": eval_completions_A, | ||
| "completions_B": eval_completions_B, | ||
| }, | ||
| extras={ | ||
| "files": { | ||
| "annotations": f"{name}-annotations.csv", | ||
| "results": f"results-{name}.json", | ||
| } | ||
| }, | ||
| judge_system_prompt=effective_judge_system_prompt, | ||
| judge_user_prompt_template=judge_user_prompt_template, | ||
| started_at_utc=run_started_at, | ||
| ) | ||
| except Exception as e: | ||
| print(f"Warning: failed to write run metadata: {e}") | ||
|
||
|
|
||
| return prefs | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can we test this function instead of the whole entrypoint?
Ideally with possible edge cases.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed it with 35f9d4f. Now I have a
test_repro function that checks write_run_metadata only.