Skip to content

Commit be13c17

Browse files
committed
fix qa test return type
1 parent bfc52d1 commit be13c17

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

llmtune/qa/qa_tests.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> U
4545
union_size = len(set_ground_truth.union(set_model_prediction))
4646

4747
similarity = intersection_size / union_size if union_size != 0 else 0
48-
return similarity
48+
return float(similarity)
4949

5050

5151
@QaTestRegistry.register("dot_product")
@@ -64,7 +64,7 @@ def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> U
6464
embedding_ground_truth = self._encode_sentence(ground_truth)
6565
embedding_model_prediction = self._encode_sentence(model_prediction)
6666
dot_product_similarity = np.dot(embedding_ground_truth, embedding_model_prediction)
67-
return dot_product_similarity
67+
return float(dot_product_similarity)
6868

6969

7070
@QaTestRegistry.register("rouge_score")
@@ -100,10 +100,9 @@ def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> U
100100

101101
common_words = words_model_prediction.intersection(words_ground_truth)
102102
overlap_percentage = (len(common_words) / len(words_ground_truth)) * 100
103-
return overlap_percentage
103+
return float(overlap_percentage)
104104

105105

106-
@QaTestRegistry.register("verb_percent")
107106
class PosCompositionTest(LLMQaTest):
108107
def _get_pos_percent(self, text: str, pos_tags: List[str]) -> float:
109108
words = word_tokenize(text)

0 commit comments

Comments
 (0)