99from rouge_score import rouge_scorer
1010from transformers import DistilBertModel , DistilBertTokenizer
1111
12- from llmtune .qa .generics import LLMQaTest , TestRegistry
12+ from llmtune .qa .generics import LLMQaTest , QaTestRegistry
1313
1414
1515model_name = "distilbert-base-uncased"
2121nltk .download ("averaged_perceptron_tagger" )
2222
2323
24- @TestRegistry .register ("summary_length" )
24+ @QaTestRegistry .register ("summary_length" )
2525class LengthTest (LLMQaTest ):
2626 @property
2727 def test_name (self ) -> str :
@@ -31,7 +31,7 @@ def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> U
3131 return abs (len (ground_truth ) - len (model_prediction ))
3232
3333
34- @TestRegistry .register ("jaccard_similarity" )
34+ @QaTestRegistry .register ("jaccard_similarity" )
3535class JaccardSimilarityTest (LLMQaTest ):
3636 @property
3737 def test_name (self ) -> str :
@@ -48,7 +48,7 @@ def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> U
4848 return similarity
4949
5050
51- @TestRegistry .register ("dot_product" )
51+ @QaTestRegistry .register ("dot_product" )
5252class DotProductSimilarityTest (LLMQaTest ):
5353 @property
5454 def test_name (self ) -> str :
@@ -67,7 +67,7 @@ def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> U
6767 return dot_product_similarity
6868
6969
70- @TestRegistry .register ("rouge_score" )
70+ @QaTestRegistry .register ("rouge_score" )
7171class RougeScoreTest (LLMQaTest ):
7272 @property
7373 def test_name (self ) -> str :
@@ -79,7 +79,7 @@ def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> U
7979 return float (scores ["rouge1" ].precision )
8080
8181
82- @TestRegistry .register ("word_overlap" )
82+ @QaTestRegistry .register ("word_overlap" )
8383class WordOverlapTest (LLMQaTest ):
8484 @property
8585 def test_name (self ) -> str :
@@ -103,6 +103,7 @@ def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> U
103103 return overlap_percentage
104104
105105
106+ @QaTestRegistry .register ("verb_percent" )
106107class PosCompositionTest (LLMQaTest ):
107108 def _get_pos_percent (self , text : str , pos_tags : List [str ]) -> float :
108109 words = word_tokenize (text )
@@ -112,7 +113,7 @@ def _get_pos_percent(self, text: str, pos_tags: List[str]) -> float:
112113 return round (len (pos_words ) / total_words , 2 )
113114
114115
115- @TestRegistry .register ("verb_percent" )
116+ @QaTestRegistry .register ("verb_percent" )
116117class VerbPercent (PosCompositionTest ):
117118 @property
118119 def test_name (self ) -> str :
@@ -122,7 +123,7 @@ def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> f
122123 return self ._get_pos_percent (model_prediction , ["VB" , "VBD" , "VBG" , "VBN" , "VBP" , "VBZ" ])
123124
124125
125- @TestRegistry .register ("adjective_percent" )
126+ @QaTestRegistry .register ("adjective_percent" )
126127class AdjectivePercent (PosCompositionTest ):
127128 @property
128129 def test_name (self ) -> str :
@@ -132,7 +133,7 @@ def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> f
132133 return self ._get_pos_percent (model_prediction , ["JJ" , "JJR" , "JJS" ])
133134
134135
135- @TestRegistry .register ("noun_percent" )
136+ @QaTestRegistry .register ("noun_percent" )
136137class NounPercent (PosCompositionTest ):
137138 @property
138139 def test_name (self ) -> str :
0 commit comments