Skip to content

Commit 19510d9

Browse files
refactor: refactor LengthEvaluator
1 parent 77bb00d commit 19510d9

File tree

1 file changed

+8
-11
lines changed

1 file changed

+8
-11
lines changed
Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,16 @@
11
from graphgen.bases.base_evaluator import BaseEvaluator
22
from graphgen.bases.datatypes import QAPair
33
from graphgen.models.tokenizer import Tokenizer
4-
from graphgen.utils import create_event_loop
54

65

76
class LengthEvaluator(BaseEvaluator):
8-
def __init__(self, tokenizer_name: str = "cl100k_base", max_concurrent: int = 100):
9-
super().__init__(max_concurrent)
10-
self.tokenizer_name = tokenizer_name
11-
self.tokenizer = Tokenizer(model_name=self.tokenizer_name)
7+
def __init__(self, tokenizer: Tokenizer):
8+
self.tokenizer = tokenizer
129

13-
async def evaluate_single(self, pair: QAPair) -> float:
14-
loop = create_event_loop()
15-
return await loop.run_in_executor(None, self._calculate_length, pair.answer)
16-
17-
def _calculate_length(self, text: str) -> float:
18-
tokens = self.tokenizer.encode(text)
10+
def evaluate(self, pair: QAPair) -> float:
11+
"""
12+
Evaluate the length of the qa pair.
13+
"""
14+
content = pair.question + pair.answer
15+
tokens = self.tokenizer.encode(content)
1916
return len(tokens)

0 commit comments

Comments
 (0)