@@ -1,39 +1,105 @@
 import pytest

-from llmtune.qa.qa_tests import (
-    JSONValidityTest,
+from llmtune.qa.qa_metrics import (
+    AdjectivePercentMetric,
+    DotProductSimilarityMetric,
+    JaccardSimilarityMetric,
+    JSONValidityMetric,
+    LengthMetric,
+    NounPercentMetric,
+    RougeScoreMetric,
+    VerbPercentMetric,
+    WordOverlapMetric,
 )


 @pytest.mark.parametrize(
-    "test_class",
+    "test_class,expected_type",
     [
-        JSONValidityTest,
+        (LengthMetric, int),
+        (JaccardSimilarityMetric, float),
+        (DotProductSimilarityMetric, float),
+        (RougeScoreMetric, float),
+        (WordOverlapMetric, float),
+        (VerbPercentMetric, float),
+        (AdjectivePercentMetric, float),
+        (NounPercentMetric, float),
+        (JSONValidityMetric, float),
     ],
 )
-def test_test_return_bool(test_class):
-    """Test to ensure that all tests return pass/fail boolean value."""
+def test_metric_return_type(test_class, expected_type):
     test_instance = test_class()
     prompt = "This is a test prompt."
     ground_truth = "This is a ground truth sentence."
     model_prediction = "This is a model predicted sentence."

-    metric_result = test_instance.test(prompt, ground_truth, model_prediction)
-    assert isinstance(metric_result, bool), f"Expected return type bool, but got {type(metric_result)}."
+    # The expected return type depends on the metric class.
+    metric_result = test_instance.get_metric(prompt, ground_truth, model_prediction)
+    assert isinstance(
+        metric_result, expected_type
+    ), f"Expected return type {expected_type}, but got {type(metric_result)}."
+
+
+def test_length_metric():
+    test = LengthMetric()
+    result = test.get_metric("prompt", "short text", "longer text")
+    assert result == 1, "Length difference should be 1."
+
+
+def test_jaccard_similarity_metric():
+    test = JaccardSimilarityMetric()
+    result = test.get_metric("prompt", "hello world", "world hello")
+    assert result == 1.0, "Jaccard similarity should be 1.0 for the same words in different orders."
+
+
+def test_dot_product_similarity_metric():
+    test = DotProductSimilarityMetric()
+    result = test.get_metric("prompt", "data", "data")
+    assert result >= 0, "Dot product similarity should be non-negative."
+
+
+def test_rouge_score_metric():
+    test = RougeScoreMetric()
+    result = test.get_metric("prompt", "the quick brown fox", "the quick brown fox jumps over the lazy dog")
+    assert result >= 0, "ROUGE precision should be non-negative."
+
+
+def test_word_overlap_metric():
+    test = WordOverlapMetric()
+    result = test.get_metric("prompt", "jump over the moon", "jump around the sun")
+    assert result >= 0, "Word overlap percentage should be non-negative."
+
+
+def test_verb_percent_metric():
+    test = VerbPercentMetric()
+    result = test.get_metric("prompt", "He eats", "He is eating")
+    assert result >= 0, "Verb percentage should be non-negative."
+
+
+def test_adjective_percent_metric():
+    test = AdjectivePercentMetric()
+    result = test.get_metric("prompt", "It is beautiful", "It is extremely beautiful")
+    assert result >= 0, "Adjective percentage should be non-negative."
+
+
+def test_noun_percent_metric():
+    test = NounPercentMetric()
+    result = test.get_metric("prompt", "The cat", "The cat and the dog")
+    assert result >= 0, "Noun percentage should be non-negative."


 @pytest.mark.parametrize(
     "input_string,expected_value",
     [
-        ('{"Answer": "The cat"}', True),
-        ("{'Answer': 'The cat'}", False),  # Double quotes are required in json
-        ('{"Answer": "The cat",}', False),  # Trailing comma is not allowed
-        ('{"Answer": "The cat", "test": "case"}', True),
-        ('```json\n{"Answer": "The cat"}\n```', True),  # this json block can still be processed
-        ('Here is an example of a JSON block: {"Answer": "The cat"}', False),
+        ('{"Answer": "The cat"}', 1),
+        ("{'Answer': 'The cat'}", 0),  # Double quotes are required in JSON
+        ('{"Answer": "The cat",}', 0),  # Trailing comma is not allowed
+        ('{"Answer": "The cat", "test": "case"}', 1),
+        ('```json\n{"Answer": "The cat"}\n```', 1),  # A fenced JSON block can still be parsed
+        ('Here is an example of a JSON block: {"Answer": "The cat"}', 0),
     ],
 )
-def test_json_valid_metric(input_string: str, expected_value: bool):
-    test = JSONValidityTest()
-    result = test.test("prompt", "The cat", input_string)
+def test_json_valid_metric(input_string: str, expected_value: float):
+    test = JSONValidityMetric()
+    result = test.get_metric("prompt", "The cat", input_string)
     assert result == expected_value, f"JSON validity should be {expected_value} but got {result}."
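
For context on the interface these tests exercise: each metric class is constructed with no arguments and exposes get_metric(prompt, ground_truth, model_prediction), returning a numeric score rather than the old pass/fail boolean. Below is a minimal sketch of what llmtune.qa.qa_metrics could look like; the base-class name LLMQaMetric is a guess, and both implementations are inferred solely from the expectations in test_length_metric and the JSON parametrize cases, not from the library's actual source.

import json
import re
from abc import ABC, abstractmethod
from typing import Union


class LLMQaMetric(ABC):
    """Hypothetical base interface implied by the tests above."""

    @abstractmethod
    def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> Union[int, float]:
        """Return a numeric score comparing the prediction to the ground truth."""


class LengthMetric(LLMQaMetric):
    def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> int:
        # Inferred from test_length_metric: "short text" (10 chars) vs.
        # "longer text" (11 chars) must yield 1, i.e. an absolute
        # character-length difference.
        return abs(len(ground_truth) - len(model_prediction))


class JSONValidityMetric(LLMQaMetric):
    def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> float:
        # Inferred from the parametrized cases: a ```json ... ``` fence is
        # stripped before parsing, but JSON embedded in surrounding prose
        # still counts as invalid.
        text = model_prediction.strip()
        fenced = re.match(r"^```json\n(.*)\n```$", text, re.DOTALL)
        if fenced:
            text = fenced.group(1)
        try:
            json.loads(text)
            return 1.0
        except json.JSONDecodeError:
            return 0.0

With this shape, adding a new metric is just a matter of subclassing and implementing get_metric, and the parametrized test_metric_return_type check above pins down each metric's return type.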