99from rouge_score import rouge_scorer
1010from transformers import DistilBertModel , DistilBertTokenizer
1111
12- from llmtune .qa .generics import LLMQaTest , TestRegistry
12+ from llmtune .qa .generics import LLMQaTest , QaTestRegistry
1313
1414
1515model_name = "distilbert-base-uncased"
2121nltk .download ("averaged_perceptron_tagger" )
2222
2323
24- @TestRegistry .register ("summary_length" )
24+ @QaTestRegistry .register ("summary_length" )
2525class LengthTest (LLMQaTest ):
2626 @property
2727 def test_name (self ) -> str :
@@ -31,7 +31,7 @@ def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> U
3131 return abs (len (ground_truth ) - len (model_prediction ))
3232
3333
34- @TestRegistry .register ("jaccard_similarity" )
34+ @QaTestRegistry .register ("jaccard_similarity" )
3535class JaccardSimilarityTest (LLMQaTest ):
3636 @property
3737 def test_name (self ) -> str :
@@ -45,10 +45,10 @@ def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> U
4545 union_size = len (set_ground_truth .union (set_model_prediction ))
4646
4747 similarity = intersection_size / union_size if union_size != 0 else 0
48- return similarity
48+ return float ( similarity )
4949
5050
51- @TestRegistry .register ("dot_product" )
51+ @QaTestRegistry .register ("dot_product" )
5252class DotProductSimilarityTest (LLMQaTest ):
5353 @property
5454 def test_name (self ) -> str :
@@ -64,10 +64,10 @@ def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> U
6464 embedding_ground_truth = self ._encode_sentence (ground_truth )
6565 embedding_model_prediction = self ._encode_sentence (model_prediction )
6666 dot_product_similarity = np .dot (embedding_ground_truth , embedding_model_prediction )
67- return dot_product_similarity
67+ return float ( dot_product_similarity )
6868
6969
70- @TestRegistry .register ("rouge_score" )
70+ @QaTestRegistry .register ("rouge_score" )
7171class RougeScoreTest (LLMQaTest ):
7272 @property
7373 def test_name (self ) -> str :
@@ -79,7 +79,7 @@ def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> U
7979 return float (scores ["rouge1" ].precision )
8080
8181
82- @TestRegistry .register ("word_overlap" )
82+ @QaTestRegistry .register ("word_overlap" )
8383class WordOverlapTest (LLMQaTest ):
8484 @property
8585 def test_name (self ) -> str :
@@ -100,7 +100,7 @@ def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> U
100100
101101 common_words = words_model_prediction .intersection (words_ground_truth )
102102 overlap_percentage = (len (common_words ) / len (words_ground_truth )) * 100
103- return overlap_percentage
103+ return float ( overlap_percentage )
104104
105105
106106class PosCompositionTest (LLMQaTest ):
@@ -112,7 +112,7 @@ def _get_pos_percent(self, text: str, pos_tags: List[str]) -> float:
112112 return round (len (pos_words ) / total_words , 2 )
113113
114114
115- @TestRegistry .register ("verb_percent" )
115+ @QaTestRegistry .register ("verb_percent" )
116116class VerbPercent (PosCompositionTest ):
117117 @property
118118 def test_name (self ) -> str :
@@ -122,7 +122,7 @@ def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> f
122122 return self ._get_pos_percent (model_prediction , ["VB" , "VBD" , "VBG" , "VBN" , "VBP" , "VBZ" ])
123123
124124
125- @TestRegistry .register ("adjective_percent" )
125+ @QaTestRegistry .register ("adjective_percent" )
126126class AdjectivePercent (PosCompositionTest ):
127127 @property
128128 def test_name (self ) -> str :
@@ -132,7 +132,7 @@ def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> f
132132 return self ._get_pos_percent (model_prediction , ["JJ" , "JJR" , "JJS" ])
133133
134134
135- @TestRegistry .register ("noun_percent" )
135+ @QaTestRegistry .register ("noun_percent" )
136136class NounPercent (PosCompositionTest ):
137137 @property
138138 def test_name (self ) -> str :
0 commit comments