 import pytest

-from llmtune.qa.qa_metrics import (
-    AdjectivePercentMetric,
-    DotProductSimilarityMetric,
-    JaccardSimilarityMetric,
-    JSONValidityMetric,
-    LengthMetric,
-    NounPercentMetric,
-    RougeScoreMetric,
-    VerbPercentMetric,
-    WordOverlapMetric,
+from llmtune.qa.qa_tests import (
+    JSONValidityTest,
 )


 @pytest.mark.parametrize(
-    "test_class,expected_type",
+    "test_class",
     [
-        (LengthMetric, int),
-        (JaccardSimilarityMetric, float),
-        (DotProductSimilarityMetric, float),
-        (RougeScoreMetric, float),
-        (WordOverlapMetric, float),
-        (VerbPercentMetric, float),
-        (AdjectivePercentMetric, float),
-        (NounPercentMetric, float),
-        (JSONValidityMetric, float),
+        JSONValidityTest,
     ],
 )
-def test_metric_return_type(test_class, expected_type):
+def test_test_return_bool(test_class):
+    """Test to ensure that all tests return a pass/fail boolean value."""
     test_instance = test_class()
     prompt = "This is a test prompt."
     ground_truth = "This is a ground truth sentence."
     model_prediction = "This is a model predicted sentence."

-    # Depending on the test class, the output could be different.
-    metric_result = test_instance.get_metric(prompt, ground_truth, model_prediction)
-    assert isinstance(
-        metric_result, expected_type
-    ), f"Expected return type {expected_type}, but got {type(metric_result)}."
-
-
-def test_length_metric():
-    test = LengthMetric()
-    result = test.get_metric("prompt", "short text", "longer text")
-    assert result == 1, "Length difference should be 1."
-
-
-def test_jaccard_similarity_metric():
-    test = JaccardSimilarityMetric()
-    result = test.get_metric("prompt", "hello world", "world hello")
-    assert result == 1.0, "Jaccard similarity should be 1.0 for the same words in different orders."
-
-
-def test_dot_product_similarity_metric():
-    test = DotProductSimilarityMetric()
-    result = test.get_metric("prompt", "data", "data")
-    assert result >= 0, "Dot product similarity should be non-negative."
-
-
-def test_rouge_score_metric():
-    test = RougeScoreMetric()
-    result = test.get_metric("prompt", "the quick brown fox", "the quick brown fox jumps over the lazy dog")
-    assert result >= 0, "ROUGE precision should be non-negative."
-
-
-def test_word_overlap_metric():
-    test = WordOverlapMetric()
-    result = test.get_metric("prompt", "jump over the moon", "jump around the sun")
-    assert result >= 0, "Word overlap percentage should be non-negative."
-
-
-def test_verb_percent_metric():
-    test = VerbPercentMetric()
-    result = test.get_metric("prompt", "He eats", "He is eating")
-    assert result >= 0, "Verb percentage should be non-negative."
-
-
-def test_adjective_percent_metric():
-    test = AdjectivePercentMetric()
-    result = test.get_metric("prompt", "It is beautiful", "It is extremely beautiful")
-    assert result >= 0, "Adjective percentage should be non-negative."
-
-
-def test_noun_percent_metric():
-    test = NounPercentMetric()
-    result = test.get_metric("prompt", "The cat", "The cat and the dog")
-    assert result >= 0, "Noun percentage should be non-negative."
+    metric_result = test_instance.test(prompt, ground_truth, model_prediction)
+    assert isinstance(metric_result, bool), f"Expected return type bool, but got {type(metric_result)}."


 @pytest.mark.parametrize(
     "input_string,expected_value",
     [
-        ('{"Answer": "The cat"}', 1),
-        ("{'Answer': 'The cat'}", 0),  # Double quotes are required in json
-        ('{"Answer": "The cat",}', 0),
-        ('{"Answer": "The cat", "test": "case"}', 1),
-        ('```json\n{"Answer": "The cat"}\n```', 1),  # this json block can still be processed
-        ('Here is an example of a JSON block: {"Answer": "The cat"}', 0),
+        ('{"Answer": "The cat"}', True),
+        ("{'Answer': 'The cat'}", False),  # Double quotes are required in JSON
+        ('{"Answer": "The cat",}', False),  # Trailing comma is not allowed
+        ('{"Answer": "The cat", "test": "case"}', True),
+        ('```json\n{"Answer": "The cat"}\n```', True),  # This JSON block can still be processed
+        ('Here is an example of a JSON block: {"Answer": "The cat"}', False),
     ],
 )
-def test_json_valid_metric(input_string: str, expected_value: float):
-    test = JSONValidityMetric()
-    result = test.get_metric("prompt", "The cat", input_string)
+def test_json_valid_metric(input_string: str, expected_value: bool):
+    test = JSONValidityTest()
+    result = test.test("prompt", "The cat", input_string)
     assert result == expected_value, f"JSON validity should be {expected_value} but got {result}."
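For context, here is a minimal sketch of a test class shaped the way both parametrized tests require: a no-argument constructor and a `test(prompt, ground_truth, model_prediction)` method that returns a plain bool. This is a hypothetical reconstruction, not the actual `llmtune.qa.qa_tests` source; in particular, the fence-unwrapping step is an assumption inferred from the fenced-JSON case above that is still expected to parse.

```python
# Hypothetical sketch, NOT the actual llmtune.qa.qa_tests implementation.
import json
import re


class JSONValidityTest:
    """Passes only when the model prediction parses as JSON."""

    def test(self, prompt: str, ground_truth: str, model_prediction: str) -> bool:
        candidate = model_prediction.strip()
        # Assumption: unwrap a ```json ... ``` fence when the entire output is
        # one fenced block, which is why the fenced test case expects True.
        fenced = re.fullmatch(r"```json\s*(.*?)\s*```", candidate, re.DOTALL)
        if fenced:
            candidate = fenced.group(1)
        try:
            json.loads(candidate)
            return True  # valid JSON: test passes
        except json.JSONDecodeError:
            return False  # anything unparseable: test fails
```

The contract change the diff pins down is that the old `get_metric` interface returned heterogeneous numeric scores (int or float depending on the class), while the new `test` interface returns a single pass/fail bool, so the first parametrized test can assert one return type for every test class.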