-
Notifications
You must be signed in to change notification settings - Fork 0
Standardized modelgauge column names #50
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 11 commits
796c9a8
fd6419a
19a4b93
61f83e8
d37baa0
211f0ff
1913534
65bd132
f3f8912
163f6af
050d572
c2a82d1
5f5ff70
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,23 +1,18 @@ | ||
| """Runway for annotating responses from SUTs. | ||
|
|
||
| TODO: PROMPT_CSV_INPUT_COLUMNS / ANNOTATOR_CSV_INPUT_COLUMNS should be aligned | ||
| """ | ||
| """Runway for annotating responses from SUTs.""" | ||
|
|
||
| import collections | ||
| import csv | ||
| import os | ||
| import pathlib | ||
| import tempfile | ||
| from typing import Any, Dict, List | ||
|
|
||
| import jsonlines | ||
| import mlflow | ||
| import numpy as np | ||
| from matplotlib import pyplot as plt | ||
| from modelgauge.annotation_pipeline import ANNOTATOR_CSV_INPUT_COLUMNS | ||
| from modelgauge.annotator import Annotator | ||
| from modelgauge.annotator_registry import ANNOTATORS | ||
| from modelgauge.annotator_set import AnnotatorSet | ||
| from modelgauge.dataset import AnnotationDataset | ||
| from modelgauge.ensemble_annotator_set import ENSEMBLE_STRATEGIES, EnsembleAnnotatorSet | ||
| from modelgauge.pipeline_runner import build_runner | ||
|
|
||
|
|
@@ -105,8 +100,7 @@ def annotate( | |
| dest_dir=tmp, | ||
| ) | ||
| input_data.log_input() | ||
| # TODO: maybe the transformation should be handled by the dataset class? | ||
| input_path = transform_annotation_file(src=input_data.local_path(), dest_dir=tmp) # type: ignore | ||
| input_path = input_data.local_path() # type: ignore | ||
| pipeline_kwargs["input_path"] = pathlib.Path(input_path) | ||
| pipeline_kwargs["output_dir"] = pathlib.Path(tmp) | ||
| pipeline_runner = build_runner(**pipeline_kwargs) | ||
|
|
@@ -182,38 +176,6 @@ def _get_annotators(annotator_ids: List[str]) -> Dict[str, Annotator]: | |
| return annotators | ||
|
|
||
|
|
||
def transform_annotation_file(src: str, dest_dir: str) -> str:
    """Reshape a wide SUT-response CSV into one row per (prompt, SUT) pair.

    The source file's header must start with the columns ``UID`` and ``Text``,
    followed by one column per SUT whose cells hold that SUT's response. Each
    input row is expanded into one output row per SUT column, with the keys
    ``UID``, ``Prompt``, ``SUT`` and ``Response``. The output header is taken
    from ``ANNOTATOR_CSV_INPUT_COLUMNS``.

    Args:
        src: Path of the wide-format CSV to read.
        dest_dir: Directory in which ``transformed.csv`` is written.

    Returns:
        Path of the written file (``dest_dir`` joined with ``transformed.csv``).

    Raises:
        ValueError: If the header is missing, has fewer than three columns, or
            does not start with ``UID`` followed by ``Text``.
    """
    dest_path = os.path.join(dest_dir, "transformed.csv")
    # newline="" hands line-ending handling to the csv module, so newlines
    # inside quoted fields survive and no extra "\r" is written on Windows.
    with open(src, "r", newline="") as f:
        reader = csv.DictReader(f)
        fieldnames = reader.fieldnames
        if (
            fieldnames is None
            or len(fieldnames) < 3
            or fieldnames[0] != "UID"
            or fieldnames[1] != "Text"
        ):
            raise ValueError(
                "Invalid input file. Must have columns: UID, Text, sut_id1, sut_id2."
            )
        # Every column after UID and Text is treated as a SUT identifier.
        suts = fieldnames[2:]
        with open(dest_path, "w", newline="") as f_out:
            writer = csv.DictWriter(f_out, fieldnames=ANNOTATOR_CSV_INPUT_COLUMNS)
            writer.writeheader()
            for row in reader:
                for sut in suts:
                    writer.writerow(
                        {
                            "UID": row["UID"],
                            "Prompt": row["Text"],
                            "SUT": sut,
                            "Response": row[sut],
                        }
                    )
    return dest_path
|
|
||
|
|
||
| def log_safety_summary( | ||
| annotator_uids: List[str], | ||
| data_path: str, | ||
|
|
@@ -223,35 +185,27 @@ def log_safety_summary( | |
| total_safe = collections.Counter() | ||
| total = collections.Counter() | ||
|
|
||
| with open(data_path, "r") as f: | ||
| reader = jsonlines.Reader(f) | ||
| for row in reader: | ||
| for annotator_uid in annotator_uids: | ||
| if ( | ||
| annotator_uid not in row["Annotations"] | ||
| or "is_safe" not in row["Annotations"][annotator_uid] | ||
| ): | ||
| continue | ||
| is_safe = row["Annotations"][annotator_uid]["is_safe"] | ||
| if "logprobs" in row["Annotations"][annotator_uid]: | ||
| all_log_probs[annotator_uid].append( | ||
| row["Annotations"][annotator_uid]["logprobs"] | ||
| ) | ||
|
|
||
| total[annotator_uid] += 1 | ||
| if is_safe: | ||
| total_safe[annotator_uid] += 1 | ||
|
|
||
| for annotator_uid in annotator_uids: | ||
| mlflow.log_metric(f"{annotator_uid}_total_safe", total_safe[annotator_uid]) | ||
| mlflow.log_metric(f"{annotator_uid}_total_count", total[annotator_uid]) | ||
| # TODO: the format for the log probs isn't always the same | ||
| # in particular, the private ensemble uses a different format | ||
| try: | ||
| log_stats(f"{annotator_uid}_logprobs_", all_log_probs[annotator_uid]) | ||
| log_hist(dir, f"{annotator_uid}", all_log_probs[annotator_uid]) | ||
| except Exception as e: | ||
| print(f"Failed to log stats for {annotator_uid}: {e}") | ||
| with AnnotationDataset(data_path, "r") as dataset: | ||
| for item in dataset: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'm not totally sure how this works when there are multiple annotators? If I'm reading the
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. That changed in the modelgauge PR! Every row is now one response and one annotation.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. That's probably why the randomness changed! Previously the code looped over rows, with an inner loop over the annotators for each row, and the order of those annotators may not match the order modelgauge now produces. I'm good on this then! |
||
| if "is_safe" not in item.annotation: | ||
| continue | ||
| is_safe = item.annotation["is_safe"] | ||
| if "logprobs" in item.annotation: | ||
| all_log_probs[item.annotator_uid].append(item.annotation["logprobs"]) | ||
| total[item.annotator_uid] += 1 | ||
| if is_safe: | ||
| total_safe[item.annotator_uid] += 1 | ||
| # Log metrics. | ||
| for annotator_uid in annotator_uids: | ||
| mlflow.log_metric(f"{annotator_uid}_total_safe", total_safe[annotator_uid]) | ||
| mlflow.log_metric(f"{annotator_uid}_total_count", total[annotator_uid]) | ||
| # TODO: the format for the log probs isn't always the same | ||
| # in particular, the private ensemble uses a different format | ||
| try: | ||
| log_stats(f"{annotator_uid}_logprobs_", all_log_probs[annotator_uid]) | ||
| log_hist(dir, f"{annotator_uid}", all_log_probs[annotator_uid]) | ||
| except Exception as e: | ||
| print(f"Failed to log stats for {annotator_uid}: {e}") | ||
|
|
||
|
|
||
| def log_stats(tag_prefix, values): | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,4 @@ | ||
| UID,SUT,is_safe | ||
| prompt_uid,sut_uid,is_safe | ||
| 0,demo_yes_no,safe | ||
| 1,demo_yes_no,safe | ||
| 2,demo_yes_no,safe | ||
|
|
||
bkorycki marked this conversation as resolved.
Show resolved
Hide resolved
|
Uh oh!
There was an error while loading. Please reload this page.