Commit 31924fb

adding unit tests for QA tests
Parent: 63a7b30

5 files changed (+129, -87 lines)

llmtune/qa/generics.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -1,5 +1,4 @@
 import statistics
-from abc import ABC, abstractmethod
 from pathlib import Path
 from typing import Dict, List, Union

@@ -14,6 +13,7 @@ class LLMMetricSuite:
     Represents and runs a suite of metrics on a set of prompts,
     golden responses, and model predictions.
     """
+
     def __init__(
         self,
         metrics: List[LLMQaMetric],
```

llmtune/qa/qa_metrics.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -24,6 +24,7 @@ class LLMQaMetric(ABC):
     Abstract base class for a metric. A metric can be computed over a single
     data instance, and outputs a scalar value (integer or float).
     """
+
     @property
     @abstractmethod
     def metric_name(self) -> str:

@@ -82,6 +83,7 @@ def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> U
 class DotProductSimilarityMetric(LLMQaMetric):
     """Encodes both the ground truth and model prediction using DistilBERT, and
     computes the dot product similarity between the two embeddings."""
+
     def __init__(self):
         model_name = "distilbert-base-uncased"
         self.tokenizer = DistilBertTokenizer.from_pretrained(model_name)
```
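These two hunks only add blank lines after docstrings, but they expose the LLMQaMetric contract: an abstract metric_name property plus a get_metric(prompt, ground_truth, model_prediction) method returning a scalar. A minimal sketch of a concrete subclass under that assumption (the class and metric name here are hypothetical, not part of the library):

```python
from llmtune.qa.qa_metrics import LLMQaMetric


class CharLengthDiffMetric(LLMQaMetric):
    """Hypothetical metric: absolute character-length difference between
    the ground truth and the model prediction."""

    @property
    def metric_name(self) -> str:
        return "char_length_diff"

    def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> int:
        # Returns a scalar (int), per the LLMQaMetric docstring above.
        return abs(len(ground_truth) - len(model_prediction))
```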

llmtune/qa/qa_tests.py

Lines changed: 4 additions & 3 deletions
```diff
@@ -1,5 +1,4 @@
 from abc import ABC, abstractmethod
-from typing import List, Union

 from langchain.evaluation import JsonValidityEvaluator

@@ -9,6 +8,7 @@ class LLMQaTest(ABC):
     Abstract base class for a test. A test can be computed over a single
     data instance/llm response, and outputs a boolean value (pass or fail).
     """
+
     @property
     @abstractmethod
     def test_name(self) -> str:

@@ -25,14 +25,15 @@ class JSONValidityTest(LLMQaTest):
     to langchain_core.utils.json.parse_json_markdown
     The JSON can be wrapped in markdown and this test will still pass
     """
+
     def __init__(self):
         self.json_validity_evaluator = JsonValidityEvaluator()

     @property
     def test_name(self) -> str:
         return "json_valid"

-    def get_metric(self, model_prediction: str) -> bool:
-        result = self.json_validity_evaluator.evaluate_strings(prediction=model_prediction)
+    def test(self, prompt: str, ground_truth: str, model_pred: str) -> bool:
+        result = self.json_validity_evaluator.evaluate_strings(prediction=model_pred)
         binary_res = result["score"]
         return bool(binary_res)
```
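The rename from get_metric to test brings JSONValidityTest in line with the three-argument signature the updated unit tests call, and the bool return matches the LLMQaTest docstring (pass or fail). A quick usage sketch; note that this particular test ignores the prompt and ground truth and only validates the prediction:

```python
from llmtune.qa.qa_tests import JSONValidityTest

json_test = JSONValidityTest()

# Valid JSON passes; markdown-fenced JSON also passes, per the docstring.
assert json_test.test("prompt", "ground truth", '{"Answer": "The cat"}') is True

# Anything the evaluator cannot parse as JSON fails.
assert json_test.test("prompt", "ground truth", "not json at all") is False
```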

tests/qa/test_qa_metrics.py

Lines changed: 17 additions & 83 deletions
```diff
@@ -1,105 +1,39 @@
 import pytest

-from llmtune.qa.qa_metrics import (
-    AdjectivePercentMetric,
-    DotProductSimilarityMetric,
-    JaccardSimilarityMetric,
-    JSONValidityMetric,
-    LengthMetric,
-    NounPercentMetric,
-    RougeScoreMetric,
-    VerbPercentMetric,
-    WordOverlapMetric,
+from llmtune.qa.qa_tests import (
+    JSONValidityTest,
 )


 @pytest.mark.parametrize(
-    "test_class,expected_type",
+    "test_class",
     [
-        (LengthMetric, int),
-        (JaccardSimilarityMetric, float),
-        (DotProductSimilarityMetric, float),
-        (RougeScoreMetric, float),
-        (WordOverlapMetric, float),
-        (VerbPercentMetric, float),
-        (AdjectivePercentMetric, float),
-        (NounPercentMetric, float),
-        (JSONValidityMetric, float),
+        JSONValidityTest,
     ],
 )
-def test_metric_return_type(test_class, expected_type):
+def test_test_return_bool(test_class):
+    """Test to ensure that all tests return pass/fail boolean value."""
     test_instance = test_class()
     prompt = "This is a test prompt."
     ground_truth = "This is a ground truth sentence."
     model_prediction = "This is a model predicted sentence."

-    # Depending on the test class, the output could be different.
-    metric_result = test_instance.get_metric(prompt, ground_truth, model_prediction)
-    assert isinstance(
-        metric_result, expected_type
-    ), f"Expected return type {expected_type}, but got {type(metric_result)}."
-
-
-def test_length_metric():
-    test = LengthMetric()
-    result = test.get_metric("prompt", "short text", "longer text")
-    assert result == 1, "Length difference should be 1."
-
-
-def test_jaccard_similarity_metric():
-    test = JaccardSimilarityMetric()
-    result = test.get_metric("prompt", "hello world", "world hello")
-    assert result == 1.0, "Jaccard similarity should be 1.0 for the same words in different orders."
-
-
-def test_dot_product_similarity_metric():
-    test = DotProductSimilarityMetric()
-    result = test.get_metric("prompt", "data", "data")
-    assert result >= 0, "Dot product similarity should be non-negative."
-
-
-def test_rouge_score_metric():
-    test = RougeScoreMetric()
-    result = test.get_metric("prompt", "the quick brown fox", "the quick brown fox jumps over the lazy dog")
-    assert result >= 0, "ROUGE precision should be non-negative."
-
-
-def test_word_overlap_metric():
-    test = WordOverlapMetric()
-    result = test.get_metric("prompt", "jump over the moon", "jump around the sun")
-    assert result >= 0, "Word overlap percentage should be non-negative."
-
-
-def test_verb_percent_metric():
-    test = VerbPercentMetric()
-    result = test.get_metric("prompt", "He eats", "He is eating")
-    assert result >= 0, "Verb percentage should be non-negative."
-
-
-def test_adjective_percent_metric():
-    test = AdjectivePercentMetric()
-    result = test.get_metric("prompt", "It is beautiful", "It is extremely beautiful")
-    assert result >= 0, "Adjective percentage should be non-negative."
-
-
-def test_noun_percent_metric():
-    test = NounPercentMetric()
-    result = test.get_metric("prompt", "The cat", "The cat and the dog")
-    assert result >= 0, "Noun percentage should be non-negative."
+    metric_result = test_instance.test(prompt, ground_truth, model_prediction)
+    assert isinstance(metric_result, bool), f"Expected return type bool, but got {type(metric_result)}."


 @pytest.mark.parametrize(
     "input_string,expected_value",
     [
-        ('{"Answer": "The cat"}', 1),
-        ("{'Answer': 'The cat'}", 0),  # Double quotes are required in json
-        ('{"Answer": "The cat",}', 0),
-        ('{"Answer": "The cat", "test": "case"}', 1),
-        ('```json\n{"Answer": "The cat"}\n```', 1),  # this json block can still be processed
-        ('Here is an example of a JSON block: {"Answer": "The cat"}', 0),
+        ('{"Answer": "The cat"}', True),
+        ("{'Answer': 'The cat'}", False),  # Double quotes are required in json
+        ('{"Answer": "The cat",}', False),  # Trailing comma is not allowed
+        ('{"Answer": "The cat", "test": "case"}', True),
+        ('```json\n{"Answer": "The cat"}\n```', True),  # this json block can still be processed
+        ('Here is an example of a JSON block: {"Answer": "The cat"}', False),
     ],
 )
-def test_json_valid_metric(input_string: str, expected_value: float):
-    test = JSONValidityMetric()
-    result = test.get_metric("prompt", "The cat", input_string)
+def test_json_valid_metric(input_string: str, expected_value: bool):
+    test = JSONValidityTest()
+    result = test.test("prompt", "The cat", input_string)
     assert result == expected_value, f"JSON validity should be {expected_value} but got {result}."
```
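One subtlety in the rewritten expectations: bool is a subclass of int in Python, so True == 1 and False == 0, and the old integer expected values would still have compared equal. The switch to True/False changes nothing behaviorally; it simply documents the boolean contract:

```python
# bool subclasses int, so these equalities hold; True/False simply
# matches JSONValidityTest's declared bool return type.
assert True == 1 and False == 0
assert isinstance(True, int)
```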

tests/qa/test_qa_tests.py

Lines changed: 105 additions & 0 deletions
```diff
@@ -0,0 +1,105 @@
+import pytest
+
+from llmtune.qa.qa_metrics import (
+    AdjectivePercentMetric,
+    DotProductSimilarityMetric,
+    JaccardSimilarityMetric,
+    JSONValidityMetric,
+    LengthMetric,
+    NounPercentMetric,
+    RougeScoreMetric,
+    VerbPercentMetric,
+    WordOverlapMetric,
+)
+
+
+@pytest.mark.parametrize(
+    "test_class,expected_type",
+    [
+        (LengthMetric, int),
+        (JaccardSimilarityMetric, float),
+        (DotProductSimilarityMetric, float),
+        (RougeScoreMetric, float),
+        (WordOverlapMetric, float),
+        (VerbPercentMetric, float),
+        (AdjectivePercentMetric, float),
+        (NounPercentMetric, float),
+        (JSONValidityMetric, float),
+    ],
+)
+def test_metric_return_type(test_class, expected_type):
+    test_instance = test_class()
+    prompt = "This is a test prompt."
+    ground_truth = "This is a ground truth sentence."
+    model_prediction = "This is a model predicted sentence."
+
+    # Depending on the test class, the output could be different.
+    metric_result = test_instance.get_metric(prompt, ground_truth, model_prediction)
+    assert isinstance(
+        metric_result, expected_type
+    ), f"Expected return type {expected_type}, but got {type(metric_result)}."
+
+
+def test_length_metric():
+    test = LengthMetric()
+    result = test.get_metric("prompt", "short text", "longer text")
+    assert result == 1, "Length difference should be 1."
+
+
+def test_jaccard_similarity_metric():
+    test = JaccardSimilarityMetric()
+    result = test.get_metric("prompt", "hello world", "world hello")
+    assert result == 1.0, "Jaccard similarity should be 1.0 for the same words in different orders."
+
+
+def test_dot_product_similarity_metric():
+    test = DotProductSimilarityMetric()
+    result = test.get_metric("prompt", "data", "data")
+    assert result >= 0, "Dot product similarity should be non-negative."
+
+
+def test_rouge_score_metric():
+    test = RougeScoreMetric()
+    result = test.get_metric("prompt", "the quick brown fox", "the quick brown fox jumps over the lazy dog")
+    assert result >= 0, "ROUGE precision should be non-negative."
+
+
+def test_word_overlap_metric():
+    test = WordOverlapMetric()
+    result = test.get_metric("prompt", "jump over the moon", "jump around the sun")
+    assert result >= 0, "Word overlap percentage should be non-negative."
+
+
+def test_verb_percent_metric():
+    test = VerbPercentMetric()
+    result = test.get_metric("prompt", "He eats", "He is eating")
+    assert result >= 0, "Verb percentage should be non-negative."
+
+
+def test_adjective_percent_metric():
+    test = AdjectivePercentMetric()
+    result = test.get_metric("prompt", "It is beautiful", "It is extremely beautiful")
+    assert result >= 0, "Adjective percentage should be non-negative."
+
+
+def test_noun_percent_metric():
+    test = NounPercentMetric()
+    result = test.get_metric("prompt", "The cat", "The cat and the dog")
+    assert result >= 0, "Noun percentage should be non-negative."
+
+
+@pytest.mark.parametrize(
+    "input_string,expected_value",
+    [
+        ('{"Answer": "The cat"}', 1),
+        ("{'Answer': 'The cat'}", 0),  # Double quotes are required in json
+        ('{"Answer": "The cat",}', 0),
+        ('{"Answer": "The cat", "test": "case"}', 1),
+        ('```json\n{"Answer": "The cat"}\n```', 1),  # this json block can still be processed
+        ('Here is an example of a JSON block: {"Answer": "The cat"}', 0),
+    ],
+)
+def test_json_valid_metric(input_string: str, expected_value: float):
+    test = JSONValidityMetric()
+    result = test.get_metric("prompt", "The cat", input_string)
+    assert result == expected_value, f"JSON validity should be {expected_value} but got {result}."
```
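For context, the JsonValidityEvaluator these tests exercise indirectly comes from langchain; as the diffs above show, its evaluate_strings result carries a "score" entry (1 for valid JSON, 0 otherwise) that JSONValidityTest coerces to a bool. A minimal sketch of that raw behavior:

```python
from langchain.evaluation import JsonValidityEvaluator

evaluator = JsonValidityEvaluator()
result = evaluator.evaluate_strings(prediction='{"Answer": "The cat"}')

# The "score" entry is 1 for parseable JSON, 0 otherwise (see qa_tests.py).
assert bool(result["score"]) is True
```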
