Skip to content

Commit bf691d8

Browse files
authored
Merge pull request #178 from SinclairHudson/json-test
Adding JSON validity test
2 parents efaa6e9 + 665fd29 commit bf691d8

File tree

3 files changed

+41
-1
lines changed

3 files changed

+41
-1
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@
33
# experiment files
44
*/experiments
55
*/experiment
6+
experiment/*
67
*/archive
78
*/backup
89
*/baseline_results
@@ -49,4 +50,4 @@ venv.bak/
4950

5051
# Coverage Report
5152
.coverage
52-
/htmlcov
53+
/htmlcov

llmtune/qa/qa_tests.py

Lines changed: 20 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@
33
import nltk
44
import numpy as np
55
import torch
6+
from langchain.evaluation import JsonValidityEvaluator
67
from nltk import pos_tag
78
from nltk.corpus import stopwords
89
from nltk.tokenize import word_tokenize
@@ -12,6 +13,7 @@
1213
from llmtune.qa.generics import LLMQaTest
1314

1415

16+
# Evaluator instance created once here and reused by JSONValidityTest.get_metric.
json_validity_evaluator = JsonValidityEvaluator()
1517
# DistilBERT tokenizer and model, both built from the same checkpoint name.
model_name = "distilbert-base-uncased"
1618
tokenizer = DistilBertTokenizer.from_pretrained(model_name)
1719
model = DistilBertModel.from_pretrained(model_name)
@@ -120,6 +122,24 @@ def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> U
120122
return float(overlap_percentage)
121123

122124

125+
@QaTestRegistry.register("json_valid")
126+
class JSONValidityTest(LLMQaTest):
127+
"""
128+
Checks whether valid JSON can be parsed from the model output, according
129+
to langchain_core.utils.json.parse_json_markdown.
130+
The JSON may be wrapped in a markdown code block and this test will still pass.
Only the model prediction is evaluated; the prompt and ground truth are not used.
131+
"""
132+
133+
@property
134+
def test_name(self) -> str:
# Same name this class is registered under in QaTestRegistry.
135+
return "json_valid"
136+
137+
def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> float:
"""
Return the evaluator's "score" for model_prediction, converted to float.
The accompanying unit tests expect 1 for parseable JSON and 0 otherwise.
"""
138+
# Only the prediction is passed on; prompt and ground_truth are ignored here.
result = json_validity_evaluator.evaluate_strings(prediction=model_prediction)
139+
binary_res = result["score"]
140+
return float(binary_res)
141+
142+
123143
class PosCompositionTest(LLMQaTest):
124144
def _get_pos_percent(self, text: str, pos_tags: List[str]) -> float:
125145
words = word_tokenize(text)

tests/qa/test_qa_tests.py

Lines changed: 19 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
AdjectivePercent,
55
DotProductSimilarityTest,
66
JaccardSimilarityTest,
7+
JSONValidityTest,
78
LengthTest,
89
NounPercent,
910
RougeScoreTest,
@@ -23,6 +24,7 @@
2324
(VerbPercent, float),
2425
(AdjectivePercent, float),
2526
(NounPercent, float),
27+
(JSONValidityTest, float),
2628
],
2729
)
2830
def test_metric_return_type(test_class, expected_type):
@@ -84,3 +86,20 @@ def test_noun_percent():
8486
test = NounPercent()
8587
result = test.get_metric("prompt", "The cat", "The cat and the dog")
8688
assert result >= 0, "Noun percentage should be non-negative."
89+
90+
91+
@pytest.mark.parametrize(
92+
"input_string,expected_value",
93+
[
94+
('{"Answer": "The cat"}', 1),
95+
("{'Answer': 'The cat'}", 0), # Double quotes are required in json
96+
('{"Answer": "The cat",}', 0),  # trailing comma is expected to be rejected
97+
('{"Answer": "The cat", "test": "case"}', 1),
98+
('```json\n{"Answer": "The cat"}\n```', 1), # this json block can still be processed
99+
('Here is an example of a JSON block: {"Answer": "The cat"}', 0),  # JSON embedded in plain prose is expected to score 0
100+
],
101+
)
102+
def test_json_valid(input_string: str, expected_value: float):
"""Check that JSONValidityTest.get_metric returns the expected score for each sample string."""
103+
test = JSONValidityTest()
104+
# The prompt and ground-truth arguments are fixed placeholders; only input_string varies per case.
result = test.get_metric("prompt", "The cat", input_string)
105+
assert result == expected_value, f"JSON validity should be {expected_value} but got {result}."

0 commit comments

Comments
 (0)