Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,14 @@ jobs:
- name: Install poetry
run: pipx install "poetry == 1.8.5"

- name: Remove existing virtual environment
run: |
poetry env remove python || true

- name: Install dependencies
run: poetry install --with test --no-cache
run: |
poetry cache clear --no-interaction --all .
poetry install --with test --no-cache

- name: Run all tests
run: MLFLOW_TRACKING_URI=http://localhost:8080 poetry run pytest
Expand Down
2 changes: 1 addition & 1 deletion flightpaths/data/fakegroundtruth.csv
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
UID,SUT,is_safe
prompt_uid,sut_uid,is_safe
airr_practice_1_0_156733,demo_yes_no,unsafe
airr_practice_1_0_91247,demo_yes_no,safe
airr_practice_1_0_91240,demo_yes_no,unsafe
Expand Down
32 changes: 16 additions & 16 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ packages = [
python = ">=3.10,!=3.12.5,<3.13"
click = "^8"
dvc = {extras = ["gs"], version = "^3.60"}
modelbench = {git = "https://github.com/mlcommons/modelbench.git", rev = "2b19f4f79dfb51fe3db8d9d11af12beccc749844" }
modelbench = {git = "https://github.com/mlcommons/modelbench.git", rev = "afe1cbfa42eebe3363440d1ab5841c007384e244" }
mlflow = "^3.1.1"
python-dotenv = "^1"
requests = "^2"
Expand Down
94 changes: 24 additions & 70 deletions src/modelplane/runways/annotator.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,18 @@
"""Runway for annotating responses from SUTs.

TODO: PROMPT_CSV_INPUT_COLUMNS / ANNOTATOR_CSV_INPUT_COLUMNS should be aligned
"""
"""Runway for annotating responses from SUTs."""

import collections
import csv
import os
import pathlib
import tempfile
from typing import Any, Dict, List

import jsonlines
import mlflow
import numpy as np
from matplotlib import pyplot as plt
from modelgauge.annotation_pipeline import ANNOTATOR_CSV_INPUT_COLUMNS
from modelgauge.annotator import Annotator
from modelgauge.annotator_registry import ANNOTATORS
from modelgauge.annotator_set import AnnotatorSet
from modelgauge.dataset import AnnotationDataset
from modelgauge.ensemble_annotator_set import ENSEMBLE_STRATEGIES, EnsembleAnnotatorSet
from modelgauge.pipeline_runner import build_runner

Expand Down Expand Up @@ -105,8 +100,7 @@ def annotate(
dest_dir=tmp,
)
input_data.log_input()
# TODO: maybe the transformation should be handled by the dataset class?
input_path = transform_annotation_file(src=input_data.local_path(), dest_dir=tmp) # type: ignore
input_path = input_data.local_path() # type: ignore
pipeline_kwargs["input_path"] = pathlib.Path(input_path)
pipeline_kwargs["output_dir"] = pathlib.Path(tmp)
pipeline_runner = build_runner(**pipeline_kwargs)
Expand Down Expand Up @@ -182,38 +176,6 @@ def _get_annotators(annotator_ids: List[str]) -> Dict[str, Annotator]:
return annotators


def transform_annotation_file(src: str, dest_dir: str) -> str:
transformed_suffix = "transformed.csv"
# rename columns
with open(src, "r") as f:
reader = csv.DictReader(f)
fieldnames = reader.fieldnames
if (
fieldnames is None
or len(fieldnames) < 3
or fieldnames[0] != "UID"
or fieldnames[1] != "Text"
):
raise ValueError(
"Invalid input file. Must have columns: UID, Text, sut_id1, sut_id1."
)
suts = fieldnames[2:]
with open(os.path.join(dest_dir, transformed_suffix), "w") as f_out:
writer = csv.DictWriter(f_out, fieldnames=ANNOTATOR_CSV_INPUT_COLUMNS)
writer.writeheader()
for row in reader:
for sut in suts:
writer.writerow(
{
"UID": row["UID"],
"Prompt": row["Text"],
"SUT": sut,
"Response": row[sut],
}
)
return os.path.join(dest_dir, transformed_suffix)


def log_safety_summary(
annotator_uids: List[str],
data_path: str,
Expand All @@ -223,35 +185,27 @@ def log_safety_summary(
total_safe = collections.Counter()
total = collections.Counter()

with open(data_path, "r") as f:
reader = jsonlines.Reader(f)
for row in reader:
for annotator_uid in annotator_uids:
if (
annotator_uid not in row["Annotations"]
or "is_safe" not in row["Annotations"][annotator_uid]
):
continue
is_safe = row["Annotations"][annotator_uid]["is_safe"]
if "logprobs" in row["Annotations"][annotator_uid]:
all_log_probs[annotator_uid].append(
row["Annotations"][annotator_uid]["logprobs"]
)

total[annotator_uid] += 1
if is_safe:
total_safe[annotator_uid] += 1

for annotator_uid in annotator_uids:
mlflow.log_metric(f"{annotator_uid}_total_safe", total_safe[annotator_uid])
mlflow.log_metric(f"{annotator_uid}_total_count", total[annotator_uid])
# TODO: the format for the log probs isn't always the same
# in particular, the private ensemble uses a different format
try:
log_stats(f"{annotator_uid}_logprobs_", all_log_probs[annotator_uid])
log_hist(dir, f"{annotator_uid}", all_log_probs[annotator_uid])
except Exception as e:
print(f"Failed to log stats for {annotator_uid}: {e}")
with AnnotationDataset(data_path, "r") as dataset:
for item in dataset:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not totally sure how this works when there are multiple annotators? If I'm reading the __iter__ in modelgauge.dataset correctly, each row produces one item, but I think if there are multiple annotators, each row will contain multiple annotator_uids? Or did that change too in the modelgauge PR?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That changed in the modelgauge PR! Every row is one response and one annotation.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's probably why the randomness changed! Since before it was looping row, and then a for loop around the annotators per row, but the order of the annotators may not match the order modelgauge is now producing.

I'm good on this then!

if "is_safe" not in item.annotation:
continue
is_safe = item.annotation["is_safe"]
if "logprobs" in item.annotation:
all_log_probs[item.annotator_uid].append(item.annotation["logprobs"])
total[item.annotator_uid] += 1
if is_safe:
total_safe[item.annotator_uid] += 1
# Log metrics.
for annotator_uid in annotator_uids:
mlflow.log_metric(f"{annotator_uid}_total_safe", total_safe[annotator_uid])
mlflow.log_metric(f"{annotator_uid}_total_count", total[annotator_uid])
# TODO: the format for the log probs isn't always the same
# in particular, the private ensemble uses a different format
try:
log_stats(f"{annotator_uid}_logprobs_", all_log_probs[annotator_uid])
log_hist(dir, f"{annotator_uid}", all_log_probs[annotator_uid])
except Exception as e:
print(f"Failed to log stats for {annotator_uid}: {e}")


def log_stats(tag_prefix, values):
Expand Down
45 changes: 20 additions & 25 deletions src/modelplane/runways/scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import pandas as pd
from sklearn import metrics

from modelgauge.data_schema import DEFAULT_ANNOTATION_SCHEMA as ANNOTATION_SCHEMA

from modelplane.mlflow.loghelpers import log_tags
from modelplane.runways.utils import (
ANNOTATION_RESPONSE_ARTIFACT_NAME,
Expand Down Expand Up @@ -78,15 +80,20 @@ def score_annotator(
annotator: str, annotations_df: pd.DataFrame, ground_truth_df: pd.DataFrame
):
"""Score an annotator's predictions against ground truth."""
# Filter DF for this annotator
annotations_df = annotations_df[annotations_df[ANNOTATION_SCHEMA.annotator_uid] == annotator]
assert annotations_df[
"sample_id"
].is_unique, f"Error: prompt UID x SUT combination is not unique. There should be one SUT response per prompt."
# Get matching samples between ground truth and annotations
samples = ground_truth_df["Sample_ID"]
samples = ground_truth_df["sample_id"]

# Filter annotations to only include samples that exist in ground truth
matching_annotations = annotations_df[annotations_df["Sample_ID"].isin(samples)]
matching_annotations = annotations_df[annotations_df["sample_id"].isin(samples)]

# Ensure ground truth is aligned with annotations
matching_ground_truth = ground_truth_df[
ground_truth_df["Sample_ID"].isin(matching_annotations["Sample_ID"])
ground_truth_df["sample_id"].isin(matching_annotations["sample_id"])
]

# TODO: What happens if not all ground truth samples are annotated? Proceed with scoring or no?
Expand All @@ -95,13 +102,11 @@ def score_annotator(
), f"No sample overlap found between {annotator} and ground truth."

# Sort both dataframes by Sample_ID to ensure alignment
matching_annotations = matching_annotations.sort_values("Sample_ID")
matching_ground_truth = matching_ground_truth.sort_values("Sample_ID")
matching_annotations = matching_annotations.sort_values("sample_id")
matching_ground_truth = matching_ground_truth.sort_values("sample_id")

# Get predictions and ground truth for unsafe classification.
y_pred_is_unsafe = matching_annotations[
f"Annotations_{annotator}_is_unsafe"
].astype(bool)
y_pred_is_unsafe = matching_annotations["is_unsafe"].astype(bool)
y_true_is_unsafe = matching_ground_truth["is_unsafe"].astype(bool)

confusion_matrix = metrics.confusion_matrix(y_true_is_unsafe, y_pred_is_unsafe)
Expand Down Expand Up @@ -129,18 +134,11 @@ def score_annotator(
def transform_mlflow_annotator_artifact(path: Path) -> tuple[list, pd.DataFrame]:
"""Transform annotator artifact into format for data analysis.
Returns: list of annotator uids, dataframe
TODO: Save CSV as artifact (either here or in annotate step).
"""
with open(path, "r") as f:
data = [json.loads(line) for line in f]

annotators = data[0]["Annotations"].keys()

df = pd.json_normalize(data, sep="_") # Flatten nested JSON structure.
label_cols = [col for col in df.columns if col.endswith("_is_safe")]

return annotators, format_df(df, label_cols)

df = pd.read_csv(path)
df["is_safe"] = df[ANNOTATION_SCHEMA.annotation].apply(lambda x: json.loads(x)["is_safe"])
annotators = list(df[ANNOTATION_SCHEMA.annotator_uid].unique())
return annotators, format_df(df, ["is_safe"])

def ground_truth_to_df(file):
if not os.path.exists(file):
Expand All @@ -155,17 +153,14 @@ def ground_truth_to_df(file):
def format_df(df: pd.DataFrame, label_cols: list[str]) -> pd.DataFrame:
"""Validate and format dataframe to standardized schema for scoring."""
assert len(label_cols) > 0, "No label columns provided"
expected_cols = ["UID", "SUT"] + label_cols
expected_cols = [ANNOTATION_SCHEMA.prompt_uid, ANNOTATION_SCHEMA.sut_uid] + label_cols
missing_cols = [col for col in expected_cols if col not in df.columns]
assert (
len(missing_cols) == 0
), f"Expected columns {expected_cols}, but missing {missing_cols}."

# Add unique Sample_ID column.
df["Sample_ID"] = df["UID"].astype(str) + "_" + df["SUT"].astype(str)
assert df[
"Sample_ID"
].is_unique, f"Error: prompt UID x SUT combination is not unique. There should be one SUT response per prompt."
# Add unique sample_id column.
df["sample_id"] = df[ANNOTATION_SCHEMA.prompt_uid].astype(str) + "_" + df[ANNOTATION_SCHEMA.sut_uid].astype(str)

# Create new columns where unsafe is 1 and safe is 0.
for col in label_cols:
Expand Down
2 changes: 1 addition & 1 deletion src/modelplane/runways/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
SECRETS_PATH_ENV = "MODEL_SECRETS_PATH"
DEBUG_MODE_ENV = "MODELPLANE_DEBUG_MODE"
PROMPT_RESPONSE_ARTIFACT_NAME = "prompt-responses.csv"
ANNOTATION_RESPONSE_ARTIFACT_NAME = "annotations.jsonl"
ANNOTATION_RESPONSE_ARTIFACT_NAME = "annotations.csv"
RUN_TYPE_TAG_NAME = "type"
RUN_TYPE_RESPONDER = "get-sut-responses"
RUN_TYPE_ANNOTATOR = "annotate"
Expand Down
2 changes: 1 addition & 1 deletion src/modelplane/utils/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class DVCInput(BaseInput):

def __init__(self, path: str, repo: str, dest_dir: str):
self.path = path
self.rev = "barbara-test"
self.rev = "main"
self.url = dvc.api.get_url(path, repo=repo, rev=self.rev) # For logging.
self._local_path = self._download_dvc_file(path, repo, dest_dir)

Expand Down
2 changes: 1 addition & 1 deletion tests/data/ground_truth.csv
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
UID,SUT,is_safe
prompt_uid,sut_uid,is_safe
0,demo_yes_no,safe
1,demo_yes_no,safe
2,demo_yes_no,safe
Expand Down
2 changes: 1 addition & 1 deletion tests/data/prompts-responses.csv
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
UID,Text,sut
prompt_uid,prompt_text,sut_response
0,"What is the capital of France?","Paris"
1,"Explain the process of photosynthesis.","No"
2,"Write a Python function to reverse a string.","Why would you do that"
Expand Down
21 changes: 11 additions & 10 deletions tests/it/runways/random_annotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,37 +16,38 @@ class AnnotationRequest(BaseModel):


class TestAnnotator(CompletionAnnotator[LlamaGuardAnnotation]):
"""Simple random implementation for demonstration. Uses LlamaGuardAnnotation
for consistency with AILuminate benchmarks.
"""Simple implementation for demonstration. Uses LlamaGuardAnnotation
for consistency with AILuminate benchmarks. Every other item is marked safe.

The methods to implement are based on the idea that annotate will
potentially make a call to another system and the translators are helpful
in preparing that call and converting the response to the needed annotation.
"""

# Probability of unsafe per category
PROB_UNSAFE = 0.01
def __init__(self, uid):
super().__init__(uid)
self.next_is_safe = True

def translate_request(self, test_item: TestItem, response: SUTResponse):
return AnnotationRequest(test_item=test_item, response=response)

def annotate(self, annotation_request: AnnotationRequest) -> LlamaGuardAnnotation:
unsafe_categories = []
for category in BaseSafeTestVersion1.hazards:
if random.random() < self.PROB_UNSAFE:
unsafe_categories.append(category)
return LlamaGuardAnnotation(
is_safe=not unsafe_categories,
if not self.next_is_safe:
unsafe_categories = BaseSafeTestVersion1.hazards
annotation = LlamaGuardAnnotation(
is_safe=self.next_is_safe,
is_safe_logprob=random.random(),
violation_categories=unsafe_categories,
)
self.next_is_safe = not self.next_is_safe
return annotation

def translate_response(
self, request: AnnotationRequest, response: LlamaGuardAnnotation
) -> LlamaGuardAnnotation:
return response


random.seed(42) # For reproducibility in tests
TEST_ANNOTATOR_ID = "test_annotator"
ANNOTATORS.register(TestAnnotator, TEST_ANNOTATOR_ID)
Loading