Commit 63a7b30

renamed all internals to LLM metrics, WIP sketch of Test class
1 parent 7598e22 commit 63a7b30

File tree

7 files changed: +317 additions, -272 deletions


llmtune/cli/toolkit.py

Lines changed: 6 additions & 6 deletions

@@ -15,8 +15,8 @@
 from llmtune.finetune.lora import LoRAFinetune
 from llmtune.inference.lora import LoRAInference
 from llmtune.pydantic_models.config_model import Config
-from llmtune.qa.generics import LLMTestSuite
-from llmtune.qa.qa_tests import QaTestRegistry
+from llmtune.qa.generics import LLMMetricSuite
+from llmtune.qa.qa_metrics import QaMetricRegistry
 from llmtune.ui.rich_ui import RichUI
 from llmtune.utils.ablation_utils import generate_permutations
 from llmtune.utils.save_utils import DirectoryHelper
@@ -91,10 +91,10 @@ def run_one_experiment(config: Config, config_path: Path) -> None:
     qa_file_path = dir_helper.save_paths.qa_file
     if not qa_file_path.exists():
         llm_metrics = config.qa.llm_metrics
-        tests = QaTestRegistry.create_tests_from_list(llm_metrics)
-        test_suite = LLMTestSuite.from_csv(results_file_path, tests)
-        test_suite.save_test_results(qa_file_path)
-        test_suite.print_test_results()
+        tests = QaMetricRegistry.create_tests_from_list(llm_metrics)
+        test_suite = LLMMetricSuite.from_csv(results_file_path, tests)
+        test_suite.save_metric_results(qa_file_path)
+        test_suite.print_metric_results()
 
 
 @app.command("run")
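For reference, the renamed QA flow can also be driven outside the CLI. A minimal sketch, assuming a results CSV with the default "Prompt" / "Ground Truth" / "Predicted" columns used by LLMMetricSuite.from_csv; the file paths and the metric list below are illustrative, not part of this commit:

# Hypothetical standalone usage of the renamed classes; paths and metric names are examples.
from llmtune.qa.generics import LLMMetricSuite
from llmtune.qa.qa_metrics import QaMetricRegistry

llm_metrics = ["summary_length", "jaccard_similarity", "rouge_score"]  # e.g. config.qa.llm_metrics

metrics = QaMetricRegistry.create_tests_from_list(llm_metrics)
suite = LLMMetricSuite.from_csv("experiment/results.csv", metrics)  # assumed path
suite.print_metric_results()  # assumes the CSV has at least two rows, since the table reports a stdev
suite.save_metric_results("experiment/qa_results.csv")  # assumed path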

llmtune/qa/generics.py

Lines changed: 26 additions & 32 deletions

@@ -5,80 +5,74 @@
 
 import pandas as pd
 
+from llmtune.qa.qa_metrics import LLMQaMetric
 from llmtune.ui.rich_ui import RichUI
 
 
-class LLMQaTest(ABC):
-    @property
-    @abstractmethod
-    def test_name(self) -> str:
-        pass
-
-    @abstractmethod
-    def get_metric(self, prompt: str, grount_truth: str, model_pred: str) -> Union[float, int, bool]:
-        pass
-
-
-class LLMTestSuite:
+class LLMMetricSuite:
+    """
+    Represents and runs a suite of metrics on a set of prompts,
+    golden responses, and model predictions.
+    """
     def __init__(
         self,
-        tests: List[LLMQaTest],
+        metrics: List[LLMQaMetric],
         prompts: List[str],
         ground_truths: List[str],
         model_preds: List[str],
     ) -> None:
-        self.tests = tests
+        self.metrics = metrics
         self.prompts = prompts
         self.ground_truths = ground_truths
         self.model_preds = model_preds
 
-        self._results = {}
+        self._results: Dict[str, List[Union[float, int]]] = {}
 
     @staticmethod
     def from_csv(
         file_path: str,
-        tests: List[LLMQaTest],
+        metrics: List[LLMQaMetric],
         prompt_col: str = "Prompt",
         gold_col: str = "Ground Truth",
         pred_col="Predicted",
-    ) -> "LLMTestSuite":
+    ) -> "LLMMetricSuite":
         results_df = pd.read_csv(file_path)
         prompts = results_df[prompt_col].tolist()
         ground_truths = results_df[gold_col].tolist()
         model_preds = results_df[pred_col].tolist()
-        return LLMTestSuite(tests, prompts, ground_truths, model_preds)
+        return LLMMetricSuite(metrics, prompts, ground_truths, model_preds)
 
-    def run_tests(self) -> Dict[str, List[Union[float, int, bool]]]:
-        test_results = {}
-        for test in self.tests:
-            metrics = []
+    def compute_metrics(self) -> Dict[str, List[Union[float, int]]]:
+        results = {}
+        for metric in self.metrics:
+            metric_results = []
             for prompt, ground_truth, model_pred in zip(self.prompts, self.ground_truths, self.model_preds):
-                metrics.append(test.get_metric(prompt, ground_truth, model_pred))
-            test_results[test.test_name] = metrics
+                metric_results.append(metric.get_metric(prompt, ground_truth, model_pred))
+            results[metric.metric_name] = metric_results
 
-        self._results = test_results
-        return test_results
+        self._results = results
+        return results
 
     @property
-    def test_results(self):
-        return self._results if self._results else self.run_tests()
+    def metric_results(self) -> Dict[str, List[Union[float, int]]]:
+        return self._results if self._results else self.compute_metrics()
 
-    def print_test_results(self):
-        result_dictionary = self.test_results
+    def print_metric_results(self):
+        result_dictionary = self.metric_results
         column_data = {key: list(result_dictionary[key]) for key in result_dictionary}
         mean_values = {key: statistics.mean(column_data[key]) for key in column_data}
         median_values = {key: statistics.median(column_data[key]) for key in column_data}
         stdev_values = {key: statistics.stdev(column_data[key]) for key in column_data}
         # Use the RichUI class to display the table
         RichUI.qa_display_metric_table(result_dictionary, mean_values, median_values, stdev_values)
 
-    def save_test_results(self, path: str):
+    def save_metric_results(self, path: str):
         # TODO: save these!
         path = Path(path)
         dir = path.parent
 
         if not dir.exists():
             dir.mkdir(parents=True, exist_ok=True)
 
-        resultant_dataframe = pd.DataFrame(self.test_results)
+        resultant_dataframe = pd.DataFrame(self.metric_results)
         resultant_dataframe.to_csv(path, index=False)
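As a quick illustration of the renamed LLMMetricSuite interface, a minimal in-memory sketch (the metrics, prompts, and strings are invented for illustration; only compute_metrics is called because print_metric_results computes a standard deviation and therefore needs at least two rows):

# Illustrative only: exercises LLMMetricSuite without reading a results CSV.
from llmtune.qa.generics import LLMMetricSuite
from llmtune.qa.qa_metrics import JaccardSimilarityMetric, LengthMetric

suite = LLMMetricSuite(
    metrics=[LengthMetric(), JaccardSimilarityMetric()],
    prompts=["Summarize: the cat sat on the mat."],
    ground_truths=["A cat sat on a mat."],
    model_preds=["The cat was sitting on the mat."],
)

results = suite.compute_metrics()  # {"summary_length": [...], "jaccard_similarity": [...]}, one value per example
print(results)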

llmtune/qa/qa_metrics.py

Lines changed: 208 additions & 0 deletions

@@ -0,0 +1,208 @@
+from abc import ABC, abstractmethod
+from typing import List, Union
+
+import nltk
+import numpy as np
+import torch
+from langchain.evaluation import JsonValidityEvaluator
+from nltk import pos_tag
+from nltk.corpus import stopwords
+from nltk.tokenize import word_tokenize
+from rouge_score import rouge_scorer
+from transformers import DistilBertModel, DistilBertTokenizer
+
+
+json_validity_evaluator = JsonValidityEvaluator()
+
+nltk.download("stopwords")
+nltk.download("punkt")
+nltk.download("averaged_perceptron_tagger")
+
+
+class LLMQaMetric(ABC):
+    """
+    Abstract base class for a metric. A metric can be computed over a single
+    data instance, and outputs a scalar value (integer or float).
+    """
+    @property
+    @abstractmethod
+    def metric_name(self) -> str:
+        pass
+
+    @abstractmethod
+    def get_metric(self, prompt: str, grount_truth: str, model_pred: str) -> Union[float, int]:
+        pass
+
+
+class QaMetricRegistry:
+    registry = {}
+
+    @classmethod
+    def register(cls, *names):
+        def inner_wrapper(wrapped_class):
+            for name in names:
+                cls.registry[name] = wrapped_class
+            return wrapped_class
+
+        return inner_wrapper
+
+    @classmethod
+    def create_tests_from_list(cls, metric_names: List[str]) -> List[LLMQaMetric]:
+        return [cls.registry[test]() for test in metric_names]
+
+
+@QaMetricRegistry.register("summary_length")
+class LengthMetric(LLMQaMetric):
+    @property
+    def metric_name(self) -> str:
+        return "summary_length"
+
+    def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> Union[float, int, bool]:
+        return abs(len(ground_truth) - len(model_prediction))
+
+
+@QaMetricRegistry.register("jaccard_similarity")
+class JaccardSimilarityMetric(LLMQaMetric):
+    @property
+    def metric_name(self) -> str:
+        return "jaccard_similarity"
+
+    def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> Union[float, int, bool]:
+        set_ground_truth = set(ground_truth.lower())
+        set_model_prediction = set(model_prediction.lower())
+
+        intersection_size = len(set_ground_truth.intersection(set_model_prediction))
+        union_size = len(set_ground_truth.union(set_model_prediction))
+
+        similarity = intersection_size / union_size if union_size != 0 else 0
+        return float(similarity)
+
+
+@QaMetricRegistry.register("dot_product")
+class DotProductSimilarityMetric(LLMQaMetric):
+    """Encodes both the ground truth and model prediction using DistilBERT, and
+    computes the dot product similarity between the two embeddings."""
+    def __init__(self):
+        model_name = "distilbert-base-uncased"
+        self.tokenizer = DistilBertTokenizer.from_pretrained(model_name)
+        self.model = DistilBertModel.from_pretrained(model_name)
+
+    @property
+    def metric_name(self) -> str:
+        return "dot_product"
+
+    def _encode_sentence(self, sentence):
+        tokens = self.tokenizer(sentence, return_tensors="pt")
+        with torch.no_grad():
+            outputs = self.model(**tokens)
+        return outputs.last_hidden_state.mean(dim=1).squeeze().numpy()
+
+    def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> Union[float, int, bool]:
+        embedding_ground_truth = self._encode_sentence(ground_truth)
+        embedding_model_prediction = self._encode_sentence(model_prediction)
+        dot_product_similarity = np.dot(embedding_ground_truth, embedding_model_prediction)
+        return float(dot_product_similarity)
+
+
+@QaMetricRegistry.register("rouge_score")
+class RougeScoreMetric(LLMQaMetric):
+    @property
+    def metric_name(self) -> str:
+        return "rouge_score"
+
+    def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> Union[float, int, bool]:
+        scorer = rouge_scorer.RougeScorer(["rouge1"], use_stemmer=True)
+        scores = scorer.score(model_prediction, ground_truth)
+        return float(scores["rouge1"].precision)
+
+
+@QaMetricRegistry.register("word_overlap")
+class WordOverlapMetric(LLMQaMetric):
+    @property
+    def metric_name(self) -> str:
+        return "word_overlap"
+
+    def _remove_stopwords(self, text: str) -> str:
+        stop_words = set(stopwords.words("english"))
+        word_tokens = word_tokenize(text)
+        filtered_text = [word for word in word_tokens if word.lower() not in stop_words]
+        return " ".join(filtered_text)
+
+    def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> Union[float, int, bool]:
+        cleaned_model_prediction = self._remove_stopwords(model_prediction)
+        cleaned_ground_truth = self._remove_stopwords(ground_truth)
+
+        words_model_prediction = set(cleaned_model_prediction.split())
+        words_ground_truth = set(cleaned_ground_truth.split())
+
+        common_words = words_model_prediction.intersection(words_ground_truth)
+        overlap_percentage = (len(common_words) / len(words_ground_truth)) * 100
+        return float(overlap_percentage)
+
+
+@QaMetricRegistry.register("json_valid")
+class JSONValidityMetric(LLMQaMetric):
+    """
+    Checks to see if valid json can be parsed from the model output, according
+    to langchain_core.utils.json.parse_json_markdown
+    The JSON can be wrapped in markdown and this test will still pass
+    """
+
+    @property
+    def metric_name(self) -> str:
+        return "json_valid"
+
+    def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> float:
+        result = json_validity_evaluator.evaluate_strings(prediction=model_prediction)
+        binary_res = result["score"]
+        return float(binary_res)
+
+
+class PosCompositionMetric(LLMQaMetric):
+    def _get_pos_percent(self, text: str, pos_tags: List[str]) -> float:
+        words = word_tokenize(text)
+        tags = pos_tag(words)
+        pos_words = [word for word, tag in tags if tag in pos_tags]
+        total_words = len(text.split(" "))
+        return round(len(pos_words) / total_words, 2)
+
+
+@QaMetricRegistry.register("verb_percent")
+class VerbPercentMetric(PosCompositionMetric):
+    @property
+    def metric_name(self) -> str:
+        return "verb_percent"
+
+    def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> float:
+        return self._get_pos_percent(model_prediction, ["VB", "VBD", "VBG", "VBN", "VBP", "VBZ"])
+
+
+@QaMetricRegistry.register("adjective_percent")
+class AdjectivePercentMetric(PosCompositionMetric):
+    @property
+    def metric_name(self) -> str:
+        return "adjective_percent"
+
+    def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> float:
+        return self._get_pos_percent(model_prediction, ["JJ", "JJR", "JJS"])
+
+
+@QaMetricRegistry.register("noun_percent")
+class NounPercentMetric(PosCompositionMetric):
+    @property
+    def metric_name(self) -> str:
+        return "noun_percent"
+
+    def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> float:
+        return self._get_pos_percent(model_prediction, ["NN", "NNS", "NNP", "NNPS"])
+
+
+# Instantiate tests
+# length_test = LengthMetric()
+# jaccard_similarity_test = JaccardSimilarityMetric()
+# dot_product_similarity_test = DotProductSimilarityMetric()
+# rouge_score_test = RougeScoreMetric()
+# word_overlap_test = WordOverlapMetric()
+# verb_percent_test = VerbPercentMetric()
+# adjective_percent_test = AdjectivePercentMetric()
+# noun_percent_test = NounPercentMetric()
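Because QaMetricRegistry resolves metrics by the names listed under config.qa.llm_metrics, adding a metric only requires subclassing LLMQaMetric and registering it. A minimal sketch; the ExactMatchMetric class and its "exact_match" name are invented for illustration and are not part of this commit:

# Hypothetical example: registering an extra metric under an invented name.
from typing import Union

from llmtune.qa.qa_metrics import LLMQaMetric, QaMetricRegistry


@QaMetricRegistry.register("exact_match")
class ExactMatchMetric(LLMQaMetric):
    """Scores 1 if the prediction matches the ground truth exactly (case/whitespace ignored), else 0."""

    @property
    def metric_name(self) -> str:
        return "exact_match"

    def get_metric(self, prompt: str, ground_truth: str, model_pred: str) -> Union[float, int]:
        return int(ground_truth.strip().lower() == model_pred.strip().lower())


# Once registered, the new name resolves exactly like the built-in ones:
metrics = QaMetricRegistry.create_tests_from_list(["exact_match", "summary_length"])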
