Commit c6c0032

adding json validity test
Parent: 8e2f847

File tree: 4 files changed, +114 −1 lines


.gitignore

Lines changed: 2 additions & 1 deletion
@@ -3,6 +3,7 @@
 # experiment files
 */experiments
 */experiment
+experiment/*
 */archive
 */backup
 */baseline_results
@@ -49,4 +50,4 @@ venv.bak/

 # Coverage Report
 .coverage
-/htmlcov
+/htmlcov

examples/json_output_config.yml

Lines changed: 77 additions & 0 deletions
@@ -0,0 +1,77 @@
+save_dir: "./experiment/"
+
+ablation:
+  use_ablate: false
+
+# Data Ingestion -------------------
+data:
+  file_type: "huggingface" # one of 'json', 'csv', 'huggingface'
+  path: "azizshaw/text_to_json"
+  prompt:
+    >- # prompt, make sure column inputs are enclosed in {} brackets and that they match your data
+    {instruction}
+    Now create a json object for the following scenario
+    {input}
+  prompt_stub:
+    >- # Stub to add for training at the end of prompt, for test set or inference, this is omitted; make sure only one variable is present
+    {output}
+  test_size: 0.1 # Proportion of test as % of total; if integer then # of samples
+  train_size: 0.9 # Proportion of train as % of total; if integer then # of samples
+  train_test_split_seed: 42
+
+# Model Definition -------------------
+model:
+  hf_model_ckpt: "facebook/opt-125m"
+  torch_dtype: "bfloat16"
+  # attn_implementation: "flash_attention_2"
+  quantize: true
+  bitsandbytes:
+    load_in_4bit: true
+    bnb_4bit_compute_dtype: "bfloat16"
+    bnb_4bit_quant_type: "nf4"
+
+# LoRA Params -------------------
+lora:
+  task_type: "CAUSAL_LM"
+  r: 16
+  lora_dropout: 0.1
+  target_modules:
+    - q_proj
+    - v_proj
+    - k_proj
+    - o_proj
+    - up_proj
+    - down_proj
+    - gate_proj
+
+# Training -------------------
+training:
+  training_args:
+    num_train_epochs: 5
+    per_device_train_batch_size: 4
+    gradient_accumulation_steps: 4
+    gradient_checkpointing: true
+    optim: "paged_adamw_32bit"
+    logging_steps: 100
+    learning_rate: 2.0e-4
+    bf16: true # Set to true for mixed-precision training on newer GPUs
+    tf32: true
+    # fp16: false # Set to true for mixed-precision training on older GPUs
+    max_grad_norm: 0.3
+    warmup_ratio: 0.03
+    lr_scheduler_type: "constant"
+  sft_args:
+    max_seq_length: 5000
+    # neftune_noise_alpha: None
+
+inference:
+  max_new_tokens: 1024
+  use_cache: true
+  do_sample: true
+  top_p: 0.9
+  temperature: 0.8
+
+qa:
+  llm_tests:
+    - json_valid
+    - jaccard_similarity
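
For orientation, here is a minimal sketch of reading a config like this with plain PyYAML. The field names come straight from the file above; the loading code itself is illustrative, not the toolkit's actual loader, which may differ.

# Illustrative only: a plain PyYAML read of the config above. llmtune has its
# own config-loading machinery; this just shows the structure the file encodes.
import yaml

with open("examples/json_output_config.yml") as f:
    config = yaml.safe_load(f)

# qa.llm_tests is the list that wires up the new json_valid check.
print(config["qa"]["llm_tests"])         # ['json_valid', 'jaccard_similarity']
print(config["model"]["hf_model_ckpt"])  # 'facebook/opt-125m'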

llmtune/qa/qa_tests.py

Lines changed: 17 additions & 0 deletions
@@ -8,10 +8,12 @@
 from nltk.tokenize import word_tokenize
 from rouge_score import rouge_scorer
 from transformers import DistilBertModel, DistilBertTokenizer
+from langchain.evaluation import JsonValidityEvaluator

 from llmtune.qa.generics import LLMQaTest


+json_validity_evaluator = JsonValidityEvaluator()
 model_name = "distilbert-base-uncased"
 tokenizer = DistilBertTokenizer.from_pretrained(model_name)
 model = DistilBertModel.from_pretrained(model_name)
@@ -119,6 +121,21 @@ def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> U
         overlap_percentage = (len(common_words) / len(words_ground_truth)) * 100
         return float(overlap_percentage)

+@QaTestRegistry.register("json_valid")
+class JSONValidityTest(LLMQaTest):
+    """
+    Checks whether valid JSON can be parsed from the model output, according
+    to langchain_core.utils.json.parse_json_markdown.
+    The JSON can be wrapped in markdown and this test will still pass.
+    """
+    @property
+    def test_name(self) -> str:
+        return "json_valid"
+
+    def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> float:
+        result = json_validity_evaluator.evaluate_strings(prediction=model_prediction)
+        binary_res = result["score"]
+        return float(binary_res)

 class PosCompositionTest(LLMQaTest):
     def _get_pos_percent(self, text: str, pos_tags: List[str]) -> float:
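
JsonValidityEvaluator is part of LangChain's string-evaluator suite: evaluate_strings(prediction=...) returns a dict whose "score" key is 1 when the prediction parses as JSON and 0 otherwise. A minimal standalone sketch of the check get_metric performs, assuming langchain is installed:

# Standalone sketch of the check JSONValidityTest delegates to, outside llmtune.
from langchain.evaluation import JsonValidityEvaluator

evaluator = JsonValidityEvaluator()

print(evaluator.evaluate_strings(prediction='{"answer": "42"}'))
# {'score': 1}
print(evaluator.evaluate_strings(prediction="{'answer': '42'}"))
# {'score': 0, 'reasoning': '...'}  (a parse-error message accompanies failures)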

tests/qa/test_qa_tests.py

Lines changed: 18 additions & 0 deletions
@@ -9,6 +9,7 @@
     RougeScoreTest,
     VerbPercent,
     WordOverlapTest,
+    JSONValidityTest
 )


@@ -23,6 +24,7 @@
         (VerbPercent, float),
         (AdjectivePercent, float),
         (NounPercent, float),
+        (JSONValidityTest, float),
     ],
 )
 def test_metric_return_type(test_class, expected_type):
@@ -84,3 +86,19 @@ def test_noun_percent():
     test = NounPercent()
     result = test.get_metric("prompt", "The cat", "The cat and the dog")
     assert result >= 0, "Noun percentage should be non-negative."
+
+@pytest.mark.parametrize(
+    "input_string,expected_value",
+    [
+        ('{"Answer": "The cat"}', 1),
+        ("{'Answer': 'The cat'}", 0),  # Double quotes are required in JSON
+        ('{"Answer": "The cat",}', 0),
+        ('{"Answer": "The cat", "test": "case"}', 1),
+        ('```json\n{"Answer": "The cat"}\n```', 1),  # this JSON block can still be processed
+        ('Here is an example of a JSON block: {"Answer": "The cat"}', 0),
+    ],
+)
+def test_json_valid(input_string: str, expected_value: float):
+    test = JSONValidityTest()
+    result = test.get_metric("prompt", "The cat", input_string)
+    assert result == expected_value, f"JSON validity should be {expected_value} but got {result}."
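
The fenced-markdown case scores 1 because, as the class docstring notes, the evaluator parses via langchain_core.utils.json.parse_json_markdown, which strips a ```json fence before handing the string to the JSON parser; the prose-prefixed case scores 0 because the string neither parses directly nor contains a fence to locate the object. A quick sketch, assuming langchain_core is available (it ships with langchain):

# Why the fenced test case passes: parse_json_markdown strips the ```json
# wrapper before parsing, so the embedded object is recovered.
from langchain_core.utils.json import parse_json_markdown

print(parse_json_markdown('```json\n{"Answer": "The cat"}\n```'))
# {'Answer': 'The cat'}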
