Skip to content

Commit 665fd29

Browse files
running linter
1 parent c56c534 commit 665fd29

File tree

2 files changed

+7
-3
lines changed

2 files changed

+7
-3
lines changed

llmtune/qa/qa_tests.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3,12 +3,12 @@
33
import nltk
44
import numpy as np
55
import torch
6+
from langchain.evaluation import JsonValidityEvaluator
67
from nltk import pos_tag
78
from nltk.corpus import stopwords
89
from nltk.tokenize import word_tokenize
910
from rouge_score import rouge_scorer
1011
from transformers import DistilBertModel, DistilBertTokenizer
11-
from langchain.evaluation import JsonValidityEvaluator
1212

1313
from llmtune.qa.generics import LLMQaTest
1414

@@ -121,13 +121,15 @@ def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> U
121121
overlap_percentage = (len(common_words) / len(words_ground_truth)) * 100
122122
return float(overlap_percentage)
123123

124+
124125
@QaTestRegistry.register("json_valid")
125126
class JSONValidityTest(LLMQaTest):
126127
"""
127128
Checks to see if valid json can be parsed from the model output, according
128129
to langchain_core.utils.json.parse_json_markdown
129130
The JSON can be wrapped in markdown and this test will still pass
130131
"""
132+
131133
@property
132134
def test_name(self) -> str:
133135
return "json_valid"
@@ -137,6 +139,7 @@ def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> f
137139
binary_res = result["score"]
138140
return float(binary_res)
139141

142+
140143
class PosCompositionTest(LLMQaTest):
141144
def _get_pos_percent(self, text: str, pos_tags: List[str]) -> float:
142145
words = word_tokenize(text)

tests/qa/test_qa_tests.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@
44
AdjectivePercent,
55
DotProductSimilarityTest,
66
JaccardSimilarityTest,
7+
JSONValidityTest,
78
LengthTest,
89
NounPercent,
910
RougeScoreTest,
1011
VerbPercent,
1112
WordOverlapTest,
12-
JSONValidityTest
1313
)
1414

1515

@@ -87,11 +87,12 @@ def test_noun_percent():
8787
result = test.get_metric("prompt", "The cat", "The cat and the dog")
8888
assert result >= 0, "Noun percentage should be non-negative."
8989

90+
9091
@pytest.mark.parametrize(
9192
"input_string,expected_value",
9293
[
9394
('{"Answer": "The cat"}', 1),
94-
("{'Answer': 'The cat'}", 0), # Double quotes are required in json
95+
("{'Answer': 'The cat'}", 0), # Double quotes are required in json
9596
('{"Answer": "The cat",}', 0),
9697
('{"Answer": "The cat", "test": "case"}', 1),
9798
('```json\n{"Answer": "The cat"}\n```', 1), # this json block can still be processed

0 commit comments

Comments
 (0)