Skip to content

Commit d5b7843

Browse files
switching test_qa_metrics and test_qa_tests
1 parent 31924fb commit d5b7843

File tree

2 files changed: +100 additions, −100 deletions

tests/qa/test_qa_metrics.py

Lines changed: 83 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,105 @@
11
import pytest
22

3-
from llmtune.qa.qa_tests import (
4-
JSONValidityTest,
3+
from llmtune.qa.qa_metrics import (
4+
AdjectivePercentMetric,
5+
DotProductSimilarityMetric,
6+
JaccardSimilarityMetric,
7+
JSONValidityMetric,
8+
LengthMetric,
9+
NounPercentMetric,
10+
RougeScoreMetric,
11+
VerbPercentMetric,
12+
WordOverlapMetric,
513
)
614

715

816
@pytest.mark.parametrize(
9-
"test_class",
17+
"test_class,expected_type",
1018
[
11-
JSONValidityTest,
19+
(LengthMetric, int),
20+
(JaccardSimilarityMetric, float),
21+
(DotProductSimilarityMetric, float),
22+
(RougeScoreMetric, float),
23+
(WordOverlapMetric, float),
24+
(VerbPercentMetric, float),
25+
(AdjectivePercentMetric, float),
26+
(NounPercentMetric, float),
27+
(JSONValidityMetric, float),
1228
],
1329
)
14-
def test_test_return_bool(test_class):
15-
"""Test to ensure that all tests return pass/fail boolean value."""
30+
def test_metric_return_type(test_class, expected_type):
1631
test_instance = test_class()
1732
prompt = "This is a test prompt."
1833
ground_truth = "This is a ground truth sentence."
1934
model_prediction = "This is a model predicted sentence."
2035

21-
metric_result = test_instance.test(prompt, ground_truth, model_prediction)
22-
assert isinstance(metric_result, bool), f"Expected return type bool, but got {type(metric_result)}."
36+
# Depending on the test class, the output could be different.
37+
metric_result = test_instance.get_metric(prompt, ground_truth, model_prediction)
38+
assert isinstance(
39+
metric_result, expected_type
40+
), f"Expected return type {expected_type}, but got {type(metric_result)}."
41+
42+
43+
def test_length_metric():
44+
test = LengthMetric()
45+
result = test.get_metric("prompt", "short text", "longer text")
46+
assert result == 1, "Length difference should be 1."
47+
48+
49+
def test_jaccard_similarity_metric():
50+
test = JaccardSimilarityMetric()
51+
result = test.get_metric("prompt", "hello world", "world hello")
52+
assert result == 1.0, "Jaccard similarity should be 1.0 for the same words in different orders."
53+
54+
55+
def test_dot_product_similarity_metric():
56+
test = DotProductSimilarityMetric()
57+
result = test.get_metric("prompt", "data", "data")
58+
assert result >= 0, "Dot product similarity should be non-negative."
59+
60+
61+
def test_rouge_score_metric():
62+
test = RougeScoreMetric()
63+
result = test.get_metric("prompt", "the quick brown fox", "the quick brown fox jumps over the lazy dog")
64+
assert result >= 0, "ROUGE precision should be non-negative."
65+
66+
67+
def test_word_overlap_metric():
68+
test = WordOverlapMetric()
69+
result = test.get_metric("prompt", "jump over the moon", "jump around the sun")
70+
assert result >= 0, "Word overlap percentage should be non-negative."
71+
72+
73+
def test_verb_percent_metric():
74+
test = VerbPercentMetric()
75+
result = test.get_metric("prompt", "He eats", "He is eating")
76+
assert result >= 0, "Verb percentage should be non-negative."
77+
78+
79+
def test_adjective_percent_metric():
80+
test = AdjectivePercentMetric()
81+
result = test.get_metric("prompt", "It is beautiful", "It is extremely beautiful")
82+
assert result >= 0, "Adjective percentage should be non-negative."
83+
84+
85+
def test_noun_percent_metric():
86+
test = NounPercentMetric()
87+
result = test.get_metric("prompt", "The cat", "The cat and the dog")
88+
assert result >= 0, "Noun percentage should be non-negative."
2389

2490

2591
@pytest.mark.parametrize(
2692
"input_string,expected_value",
2793
[
28-
('{"Answer": "The cat"}', True),
29-
("{'Answer': 'The cat'}", False), # Double quotes are required in json
30-
('{"Answer": "The cat",}', False), # Trailing comma is not allowed
31-
('{"Answer": "The cat", "test": "case"}', True),
32-
('```json\n{"Answer": "The cat"}\n```', True), # this json block can still be processed
33-
('Here is an example of a JSON block: {"Answer": "The cat"}', False),
94+
('{"Answer": "The cat"}', 1),
95+
("{'Answer': 'The cat'}", 0), # Double quotes are required in json
96+
('{"Answer": "The cat",}', 0),
97+
('{"Answer": "The cat", "test": "case"}', 1),
98+
('```json\n{"Answer": "The cat"}\n```', 1), # this json block can still be processed
99+
('Here is an example of a JSON block: {"Answer": "The cat"}', 0),
34100
],
35101
)
36-
def test_json_valid_metric(input_string: str, expected_value: bool):
37-
test = JSONValidityTest()
38-
result = test.test("prompt", "The cat", input_string)
102+
def test_json_valid_metric(input_string: str, expected_value: float):
103+
test = JSONValidityMetric()
104+
result = test.get_metric("prompt", "The cat", input_string)
39105
assert result == expected_value, f"JSON validity should be {expected_value} but got {result}."

tests/qa/test_qa_tests.py

Lines changed: 17 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -1,105 +1,39 @@
11
import pytest
22

3-
from llmtune.qa.qa_metrics import (
4-
AdjectivePercentMetric,
5-
DotProductSimilarityMetric,
6-
JaccardSimilarityMetric,
7-
JSONValidityMetric,
8-
LengthMetric,
9-
NounPercentMetric,
10-
RougeScoreMetric,
11-
VerbPercentMetric,
12-
WordOverlapMetric,
3+
from llmtune.qa.qa_tests import (
4+
JSONValidityTest,
135
)
146

157

168
@pytest.mark.parametrize(
17-
"test_class,expected_type",
9+
"test_class",
1810
[
19-
(LengthMetric, int),
20-
(JaccardSimilarityMetric, float),
21-
(DotProductSimilarityMetric, float),
22-
(RougeScoreMetric, float),
23-
(WordOverlapMetric, float),
24-
(VerbPercentMetric, float),
25-
(AdjectivePercentMetric, float),
26-
(NounPercentMetric, float),
27-
(JSONValidityMetric, float),
11+
JSONValidityTest,
2812
],
2913
)
30-
def test_metric_return_type(test_class, expected_type):
14+
def test_test_return_bool(test_class):
15+
"""Test to ensure that all tests return pass/fail boolean value."""
3116
test_instance = test_class()
3217
prompt = "This is a test prompt."
3318
ground_truth = "This is a ground truth sentence."
3419
model_prediction = "This is a model predicted sentence."
3520

36-
# Depending on the test class, the output could be different.
37-
metric_result = test_instance.get_metric(prompt, ground_truth, model_prediction)
38-
assert isinstance(
39-
metric_result, expected_type
40-
), f"Expected return type {expected_type}, but got {type(metric_result)}."
41-
42-
43-
def test_length_metric():
44-
test = LengthMetric()
45-
result = test.get_metric("prompt", "short text", "longer text")
46-
assert result == 1, "Length difference should be 1."
47-
48-
49-
def test_jaccard_similarity_metric():
50-
test = JaccardSimilarityMetric()
51-
result = test.get_metric("prompt", "hello world", "world hello")
52-
assert result == 1.0, "Jaccard similarity should be 1.0 for the same words in different orders."
53-
54-
55-
def test_dot_product_similarity_metric():
56-
test = DotProductSimilarityMetric()
57-
result = test.get_metric("prompt", "data", "data")
58-
assert result >= 0, "Dot product similarity should be non-negative."
59-
60-
61-
def test_rouge_score_metric():
62-
test = RougeScoreMetric()
63-
result = test.get_metric("prompt", "the quick brown fox", "the quick brown fox jumps over the lazy dog")
64-
assert result >= 0, "ROUGE precision should be non-negative."
65-
66-
67-
def test_word_overlap_metric():
68-
test = WordOverlapMetric()
69-
result = test.get_metric("prompt", "jump over the moon", "jump around the sun")
70-
assert result >= 0, "Word overlap percentage should be non-negative."
71-
72-
73-
def test_verb_percent_metric():
74-
test = VerbPercentMetric()
75-
result = test.get_metric("prompt", "He eats", "He is eating")
76-
assert result >= 0, "Verb percentage should be non-negative."
77-
78-
79-
def test_adjective_percent_metric():
80-
test = AdjectivePercentMetric()
81-
result = test.get_metric("prompt", "It is beautiful", "It is extremely beautiful")
82-
assert result >= 0, "Adjective percentage should be non-negative."
83-
84-
85-
def test_noun_percent_metric():
86-
test = NounPercentMetric()
87-
result = test.get_metric("prompt", "The cat", "The cat and the dog")
88-
assert result >= 0, "Noun percentage should be non-negative."
21+
metric_result = test_instance.test(prompt, ground_truth, model_prediction)
22+
assert isinstance(metric_result, bool), f"Expected return type bool, but got {type(metric_result)}."
8923

9024

9125
@pytest.mark.parametrize(
9226
"input_string,expected_value",
9327
[
94-
('{"Answer": "The cat"}', 1),
95-
("{'Answer': 'The cat'}", 0), # Double quotes are required in json
96-
('{"Answer": "The cat",}', 0),
97-
('{"Answer": "The cat", "test": "case"}', 1),
98-
('```json\n{"Answer": "The cat"}\n```', 1), # this json block can still be processed
99-
('Here is an example of a JSON block: {"Answer": "The cat"}', 0),
28+
('{"Answer": "The cat"}', True),
29+
("{'Answer': 'The cat'}", False), # Double quotes are required in json
30+
('{"Answer": "The cat",}', False), # Trailing comma is not allowed
31+
('{"Answer": "The cat", "test": "case"}', True),
32+
('```json\n{"Answer": "The cat"}\n```', True), # this json block can still be processed
33+
('Here is an example of a JSON block: {"Answer": "The cat"}', False),
10034
],
10135
)
102-
def test_json_valid_metric(input_string: str, expected_value: float):
103-
test = JSONValidityMetric()
104-
result = test.get_metric("prompt", "The cat", input_string)
36+
def test_json_valid_metric(input_string: str, expected_value: bool):
37+
test = JSONValidityTest()
38+
result = test.test("prompt", "The cat", input_string)
10539
assert result == expected_value, f"JSON validity should be {expected_value} but got {result}."

0 commit comments

Comments (0)