33import nltk
44import numpy as np
55import torch
6+ from langchain .evaluation import JsonValidityEvaluator
67from nltk import pos_tag
78from nltk .corpus import stopwords
89from nltk .tokenize import word_tokenize
910from rouge_score import rouge_scorer
1011from transformers import DistilBertModel , DistilBertTokenizer
11- from langchain .evaluation import JsonValidityEvaluator
1212
1313from llmtune .qa .generics import LLMQaTest
1414
@@ -121,13 +121,15 @@ def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> U
121121 overlap_percentage = (len (common_words ) / len (words_ground_truth )) * 100
122122 return float (overlap_percentage )
123123
124+
124125@QaTestRegistry .register ("json_valid" )
125126class JSONValidityTest (LLMQaTest ):
126127 """
127128 Checks to see if valid json can be parsed from the model output, according
128129 to langchain_core.utils.json.parse_json_markdown
129130 The JSON can be wrapped in markdown and this test will still pass
130131 """
132+
131133 @property
132134 def test_name (self ) -> str :
133135 return "json_valid"
@@ -137,6 +139,7 @@ def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> f
137139 binary_res = result ["score" ]
138140 return float (binary_res )
139141
142+
140143class PosCompositionTest (LLMQaTest ):
141144 def _get_pos_percent (self , text : str , pos_tags : List [str ]) -> float :
142145 words = word_tokenize (text )
0 commit comments