
Commit e440a9d

Merge pull request #186 from SinclairHudson/adding-tests
Moving metrics to metrics and adding pass/fail LLM tests
2 parents 7598e22 + d5b7843 commit e440a9d

8 files changed: +402 -315 lines changed


llmtune/cli/toolkit.py

Lines changed: 6 additions & 6 deletions
@@ -15,8 +15,8 @@
 from llmtune.finetune.lora import LoRAFinetune
 from llmtune.inference.lora import LoRAInference
 from llmtune.pydantic_models.config_model import Config
-from llmtune.qa.generics import LLMTestSuite
-from llmtune.qa.qa_tests import QaTestRegistry
+from llmtune.qa.generics import LLMMetricSuite
+from llmtune.qa.qa_metrics import QaMetricRegistry
 from llmtune.ui.rich_ui import RichUI
 from llmtune.utils.ablation_utils import generate_permutations
 from llmtune.utils.save_utils import DirectoryHelper
@@ -91,10 +91,10 @@ def run_one_experiment(config: Config, config_path: Path) -> None:
     qa_file_path = dir_helper.save_paths.qa_file
     if not qa_file_path.exists():
         llm_metrics = config.qa.llm_metrics
-        tests = QaTestRegistry.create_tests_from_list(llm_metrics)
-        test_suite = LLMTestSuite.from_csv(results_file_path, tests)
-        test_suite.save_test_results(qa_file_path)
-        test_suite.print_test_results()
+        tests = QaMetricRegistry.create_tests_from_list(llm_metrics)
+        test_suite = LLMMetricSuite.from_csv(results_file_path, tests)
+        test_suite.save_metric_results(qa_file_path)
+        test_suite.print_metric_results()


 @app.command("run")

llmtune/qa/generics.py

Lines changed: 26 additions & 32 deletions
@@ -1,84 +1,78 @@
 import statistics
-from abc import ABC, abstractmethod
 from pathlib import Path
 from typing import Dict, List, Union

 import pandas as pd

+from llmtune.qa.qa_metrics import LLMQaMetric
 from llmtune.ui.rich_ui import RichUI


-class LLMQaTest(ABC):
-    @property
-    @abstractmethod
-    def test_name(self) -> str:
-        pass
-
-    @abstractmethod
-    def get_metric(self, prompt: str, grount_truth: str, model_pred: str) -> Union[float, int, bool]:
-        pass
-
+class LLMMetricSuite:
+    """
+    Represents and runs a suite of metrics on a set of prompts,
+    golden responses, and model predictions.
+    """

-class LLMTestSuite:
     def __init__(
         self,
-        tests: List[LLMQaTest],
+        metrics: List[LLMQaMetric],
         prompts: List[str],
         ground_truths: List[str],
         model_preds: List[str],
     ) -> None:
-        self.tests = tests
+        self.metrics = metrics
         self.prompts = prompts
         self.ground_truths = ground_truths
         self.model_preds = model_preds

-        self._results = {}
+        self._results: Dict[str, List[Union[float, int]]] = {}

     @staticmethod
     def from_csv(
         file_path: str,
-        tests: List[LLMQaTest],
+        metrics: List[LLMQaMetric],
         prompt_col: str = "Prompt",
         gold_col: str = "Ground Truth",
         pred_col="Predicted",
-    ) -> "LLMTestSuite":
+    ) -> "LLMMetricSuite":
         results_df = pd.read_csv(file_path)
         prompts = results_df[prompt_col].tolist()
         ground_truths = results_df[gold_col].tolist()
         model_preds = results_df[pred_col].tolist()
-        return LLMTestSuite(tests, prompts, ground_truths, model_preds)
+        return LLMMetricSuite(metrics, prompts, ground_truths, model_preds)

-    def run_tests(self) -> Dict[str, List[Union[float, int, bool]]]:
-        test_results = {}
-        for test in self.tests:
-            metrics = []
+    def compute_metrics(self) -> Dict[str, List[Union[float, int]]]:
+        results = {}
+        for metric in self.metrics:
+            metric_results = []
             for prompt, ground_truth, model_pred in zip(self.prompts, self.ground_truths, self.model_preds):
-                metrics.append(test.get_metric(prompt, ground_truth, model_pred))
-            test_results[test.test_name] = metrics
+                metric_results.append(metric.get_metric(prompt, ground_truth, model_pred))
+            results[metric.metric_name] = metric_results

-        self._results = test_results
-        return test_results
+        self._results = results
+        return results

     @property
-    def test_results(self):
-        return self._results if self._results else self.run_tests()
+    def metric_results(self) -> Dict[str, List[Union[float, int]]]:
+        return self._results if self._results else self.compute_metrics()

-    def print_test_results(self):
-        result_dictionary = self.test_results
+    def print_metric_results(self):
+        result_dictionary = self.metric_results
         column_data = {key: list(result_dictionary[key]) for key in result_dictionary}
         mean_values = {key: statistics.mean(column_data[key]) for key in column_data}
         median_values = {key: statistics.median(column_data[key]) for key in column_data}
         stdev_values = {key: statistics.stdev(column_data[key]) for key in column_data}
         # Use the RichUI class to display the table
         RichUI.qa_display_metric_table(result_dictionary, mean_values, median_values, stdev_values)

-    def save_test_results(self, path: str):
+    def save_metric_results(self, path: str):
         # TODO: save these!
         path = Path(path)
         dir = path.parent

         if not dir.exists():
             dir.mkdir(parents=True, exist_ok=True)

-        resultant_dataframe = pd.DataFrame(self.test_results)
+        resultant_dataframe = pd.DataFrame(self.metric_results)
         resultant_dataframe.to_csv(path, index=False)
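
Editor's note: a minimal sketch of driving `LLMMetricSuite` directly, without a CSV. The prompts, references, and predictions below are made-up placeholders; `compute_metrics` returns one list of scalar scores per metric, keyed by `metric_name`.

```python
from llmtune.qa.generics import LLMMetricSuite
from llmtune.qa.qa_metrics import JaccardSimilarityMetric, LengthMetric

suite = LLMMetricSuite(
    metrics=[LengthMetric(), JaccardSimilarityMetric()],
    prompts=["Summarize: the cat sat on the mat.", "Summarize: dogs bark at night."],
    ground_truths=["A cat sat on a mat.", "Dogs bark at night."],
    model_preds=["The cat was sitting on the mat.", "Dogs often bark at night."],
)

results = suite.compute_metrics()
# results["summary_length"] == [12, 6] (absolute character-length differences);
# results["jaccard_similarity"] holds character-set overlaps in [0, 1].
suite.save_metric_results("qa_results.csv")  # writes one column per metric
```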

llmtune/qa/qa_metrics.py

Lines changed: 210 additions & 0 deletions
@@ -0,0 +1,210 @@
+from abc import ABC, abstractmethod
+from typing import List, Union
+
+import nltk
+import numpy as np
+import torch
+from langchain.evaluation import JsonValidityEvaluator
+from nltk import pos_tag
+from nltk.corpus import stopwords
+from nltk.tokenize import word_tokenize
+from rouge_score import rouge_scorer
+from transformers import DistilBertModel, DistilBertTokenizer
+
+
+json_validity_evaluator = JsonValidityEvaluator()
+
+nltk.download("stopwords")
+nltk.download("punkt")
+nltk.download("averaged_perceptron_tagger")
+
+
+class LLMQaMetric(ABC):
+    """
+    Abstract base class for a metric. A metric can be computed over a single
+    data instance, and outputs a scalar value (integer or float).
+    """
+
+    @property
+    @abstractmethod
+    def metric_name(self) -> str:
+        pass
+
+    @abstractmethod
+    def get_metric(self, prompt: str, grount_truth: str, model_pred: str) -> Union[float, int]:
+        pass
+
+
+class QaMetricRegistry:
+    registry = {}
+
+    @classmethod
+    def register(cls, *names):
+        def inner_wrapper(wrapped_class):
+            for name in names:
+                cls.registry[name] = wrapped_class
+            return wrapped_class
+
+        return inner_wrapper
+
+    @classmethod
+    def create_tests_from_list(cls, metric_names: List[str]) -> List[LLMQaMetric]:
+        return [cls.registry[test]() for test in metric_names]
+
+
+@QaMetricRegistry.register("summary_length")
+class LengthMetric(LLMQaMetric):
+    @property
+    def metric_name(self) -> str:
+        return "summary_length"
+
+    def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> Union[float, int, bool]:
+        return abs(len(ground_truth) - len(model_prediction))
+
+
+@QaMetricRegistry.register("jaccard_similarity")
+class JaccardSimilarityMetric(LLMQaMetric):
+    @property
+    def metric_name(self) -> str:
+        return "jaccard_similarity"
+
+    def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> Union[float, int, bool]:
+        set_ground_truth = set(ground_truth.lower())
+        set_model_prediction = set(model_prediction.lower())
+
+        intersection_size = len(set_ground_truth.intersection(set_model_prediction))
+        union_size = len(set_ground_truth.union(set_model_prediction))
+
+        similarity = intersection_size / union_size if union_size != 0 else 0
+        return float(similarity)
+
+
+@QaMetricRegistry.register("dot_product")
+class DotProductSimilarityMetric(LLMQaMetric):
+    """Encodes both the ground truth and model prediction using DistilBERT, and
+    computes the dot product similarity between the two embeddings."""
+
+    def __init__(self):
+        model_name = "distilbert-base-uncased"
+        self.tokenizer = DistilBertTokenizer.from_pretrained(model_name)
+        self.model = DistilBertModel.from_pretrained(model_name)
+
+    @property
+    def metric_name(self) -> str:
+        return "dot_product"
+
+    def _encode_sentence(self, sentence):
+        tokens = self.tokenizer(sentence, return_tensors="pt")
+        with torch.no_grad():
+            outputs = self.model(**tokens)
+        return outputs.last_hidden_state.mean(dim=1).squeeze().numpy()
+
+    def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> Union[float, int, bool]:
+        embedding_ground_truth = self._encode_sentence(ground_truth)
+        embedding_model_prediction = self._encode_sentence(model_prediction)
+        dot_product_similarity = np.dot(embedding_ground_truth, embedding_model_prediction)
+        return float(dot_product_similarity)
+
+
+@QaMetricRegistry.register("rouge_score")
+class RougeScoreMetric(LLMQaMetric):
+    @property
+    def metric_name(self) -> str:
+        return "rouge_score"
+
+    def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> Union[float, int, bool]:
+        scorer = rouge_scorer.RougeScorer(["rouge1"], use_stemmer=True)
+        scores = scorer.score(model_prediction, ground_truth)
+        return float(scores["rouge1"].precision)
+
+
+@QaMetricRegistry.register("word_overlap")
+class WordOverlapMetric(LLMQaMetric):
+    @property
+    def metric_name(self) -> str:
+        return "word_overlap"
+
+    def _remove_stopwords(self, text: str) -> str:
+        stop_words = set(stopwords.words("english"))
+        word_tokens = word_tokenize(text)
+        filtered_text = [word for word in word_tokens if word.lower() not in stop_words]
+        return " ".join(filtered_text)
+
+    def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> Union[float, int, bool]:
+        cleaned_model_prediction = self._remove_stopwords(model_prediction)
+        cleaned_ground_truth = self._remove_stopwords(ground_truth)
+
+        words_model_prediction = set(cleaned_model_prediction.split())
+        words_ground_truth = set(cleaned_ground_truth.split())
+
+        common_words = words_model_prediction.intersection(words_ground_truth)
+        overlap_percentage = (len(common_words) / len(words_ground_truth)) * 100
+        return float(overlap_percentage)
+
+
+@QaMetricRegistry.register("json_valid")
+class JSONValidityMetric(LLMQaMetric):
+    """
+    Checks to see if valid json can be parsed from the model output, according
+    to langchain_core.utils.json.parse_json_markdown
+    The JSON can be wrapped in markdown and this test will still pass
+    """
+
+    @property
+    def metric_name(self) -> str:
+        return "json_valid"
+
+    def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> float:
+        result = json_validity_evaluator.evaluate_strings(prediction=model_prediction)
+        binary_res = result["score"]
+        return float(binary_res)
+
+
+class PosCompositionMetric(LLMQaMetric):
+    def _get_pos_percent(self, text: str, pos_tags: List[str]) -> float:
+        words = word_tokenize(text)
+        tags = pos_tag(words)
+        pos_words = [word for word, tag in tags if tag in pos_tags]
+        total_words = len(text.split(" "))
+        return round(len(pos_words) / total_words, 2)
+
+
+@QaMetricRegistry.register("verb_percent")
+class VerbPercentMetric(PosCompositionMetric):
+    @property
+    def metric_name(self) -> str:
+        return "verb_percent"
+
+    def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> float:
+        return self._get_pos_percent(model_prediction, ["VB", "VBD", "VBG", "VBN", "VBP", "VBZ"])
+
+
+@QaMetricRegistry.register("adjective_percent")
+class AdjectivePercentMetric(PosCompositionMetric):
+    @property
+    def metric_name(self) -> str:
+        return "adjective_percent"
+
+    def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> float:
+        return self._get_pos_percent(model_prediction, ["JJ", "JJR", "JJS"])
+
+
+@QaMetricRegistry.register("noun_percent")
+class NounPercentMetric(PosCompositionMetric):
+    @property
+    def metric_name(self) -> str:
+        return "noun_percent"
+
+    def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> float:
+        return self._get_pos_percent(model_prediction, ["NN", "NNS", "NNP", "NNPS"])
+
+
+# Instantiate tests
+# length_test = LengthMetric()
+# jaccard_similarity_test = JaccardSimilarityMetric()
+# dot_product_similarity_test = DotProductSimilarityMetric()
+# rouge_score_test = RougeScoreMetric()
+# word_overlap_test = WordOverlapMetric()
+# verb_percent_test = VerbPercentMetric()
+# adjective_percent_test = AdjectivePercentMetric()
+# noun_percent_test = NounPercentMetric()
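
Editor's note: the registry decorator makes it straightforward to plug in further metrics. A hypothetical extension (not part of this commit) might look like the sketch below; `ExactMatchMetric` and the `"exact_match"` name are invented for illustration.

```python
from typing import Union

from llmtune.qa.qa_metrics import LLMQaMetric, QaMetricRegistry


@QaMetricRegistry.register("exact_match")
class ExactMatchMetric(LLMQaMetric):
    """Hypothetical metric: 1.0 when the prediction matches the reference exactly."""

    @property
    def metric_name(self) -> str:
        return "exact_match"

    def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> Union[float, int]:
        return float(ground_truth.strip() == model_prediction.strip())


# Once registered, the metric can be requested by name alongside the built-ins:
metrics = QaMetricRegistry.create_tests_from_list(["exact_match", "summary_length"])
```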
