
Commit 58ede2e

refactor: refactor RewardEvaluator
1 parent c161358 commit 58ede2e

File tree

6 files changed (+61, -102 lines)


graphgen/bases/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -9,4 +9,5 @@
 from .base_splitter import BaseSplitter
 from .base_storage import BaseGraphStorage, BaseKVStorage, StorageNameSpace
 from .base_tokenizer import BaseTokenizer
+from .base_evaluator import BaseEvaluator
 from .datatypes import Chunk, Config, Node, QAPair, Token
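With BaseEvaluator re-exported from the package root, downstream modules can import the base class and the QAPair datatype from graphgen.bases directly, which is what the import rewrites in the evaluator files below rely on. A minimal sketch of the import this enables (assuming the package is on the Python path):

# Both names now resolve through the graphgen.bases re-exports.
from graphgen.bases import BaseEvaluator, QAPair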

graphgen/bases/base_evaluator.py

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from graphgen.bases.datatypes import QAPair
+from .datatypes import QAPair


 class BaseEvaluator(ABC):
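Only the intra-package import changes here; the body of BaseEvaluator is not part of this diff. For orientation, a minimal sketch of how the abstract interface is presumably shaped, with the evaluate signature inferred from the RewardEvaluator implementation further down (an assumption, not the verbatim source):

from abc import ABC, abstractmethod

from .datatypes import QAPair


class BaseEvaluator(ABC):
    # Assumed interface: concrete evaluators score one question-answer pair at a time.
    @abstractmethod
    def evaluate(self, pair: QAPair) -> float:
        raise NotImplementedError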

graphgen/models/evaluator/qa/length_evaluator.py

Lines changed: 1 addition & 2 deletions
@@ -1,5 +1,4 @@
-from graphgen.bases.base_evaluator import BaseEvaluator
-from graphgen.bases.datatypes import QAPair
+from graphgen.bases import BaseEvaluator, QAPair
 from graphgen.models.tokenizer import Tokenizer


graphgen/models/evaluator/qa/mtld_evaluator.py

Lines changed: 1 addition & 2 deletions
@@ -1,7 +1,6 @@
 from typing import Set

-from graphgen.bases.base_evaluator import BaseEvaluator
-from graphgen.bases.datatypes import QAPair
+from graphgen.bases import BaseEvaluator, QAPair
 from graphgen.utils import NLTKHelper, detect_main_language



Lines changed: 51 additions & 94 deletions
@@ -1,107 +1,64 @@
-from dataclasses import dataclass
+from typing import Optional
+from graphgen.bases import BaseEvaluator, QAPair
 
-from tqdm import tqdm
 
-from graphgen.bases.datatypes import QAPair
-
-
-@dataclass
-class RewardEvaluator:
+class RewardEvaluator(BaseEvaluator):
     """
-    Reward Model Evaluator.
-    OpenAssistant/reward-model-deberta-v3-large-v2: score range is [-inf, inf]; higher is better
+    Reward Model Evaluator for single QAPair evaluation.
     """
 
-    reward_name: str = "OpenAssistant/reward-model-deberta-v3-large-v2"
-    max_length: int = 2560
-    results: list[float] = None
-
-    def __post_init__(self):
-        import torch
-
-        self.num_gpus = torch.cuda.device_count()
+    def __init__(
+        self,
+        reward_name: str = "OpenAssistant/reward-model-deberta-v3-large-v2",
+        max_length: int = 2560,
+        device: Optional[str] = None,
+    ):
+        """
+        Initialize the reward evaluator.
+
+        Args:
+            reward_name: Model name or path on HuggingFace Hub
+            max_length: Maximum token length for the model
+            device: Device to run the model on. If None, auto-detect CUDA/CPU.
+        """
+        self.reward_name = reward_name
+        self.max_length = max_length
 
-    @staticmethod
-    def process_chunk(rank, pairs, reward_name, max_length, return_dict):
         import torch
         from transformers import AutoModelForSequenceClassification, AutoTokenizer
 
-        device = f"cuda:{rank}"
-        torch.cuda.set_device(rank)
-
-        rank_model = AutoModelForSequenceClassification.from_pretrained(reward_name)
-        tokenizer = AutoTokenizer.from_pretrained(reward_name)
-        rank_model.to(device)
-        rank_model.eval()
-
-        results = []
-        with torch.no_grad():
-            for pair in tqdm(pairs):
-                inputs = tokenizer(
-                    pair.question,
-                    pair.answer,
-                    return_tensors="pt",
-                    max_length=max_length,
-                    truncation=True,
-                )
-                inputs = {k: v.to(device) for k, v in inputs.items()}
-                score = rank_model(**inputs).logits[0].item()
-                results.append(score)
-
-        return_dict[rank] = results
+        # Set device (auto-detect if not specified)
+        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
 
-    def evaluate(self, pairs: list[QAPair]) -> list[float]:
-        import torch.multiprocessing as mp
-
-        chunk_size = len(pairs) // self.num_gpus
-        chunks = []
-        for i in range(self.num_gpus):
-            start = i * chunk_size
-            end = start + chunk_size
-            if i == self.num_gpus - 1:
-                end = len(pairs)
-            chunks.append(pairs[start:end])
-
-        # multi-process
-        manager = mp.Manager()
-        return_dict = manager.dict()
-        processes = []
-
-        for rank, chunk in enumerate(chunks):
-            p = mp.Process(
-                target=self.process_chunk,
-                args=(rank, chunk, self.reward_name, self.max_length, return_dict),
-            )
-            p.start()
-            processes.append(p)
-
-        for p in processes:
-            p.join()
-
-        # merge results
-        results = []
-        for rank in range(len(chunks)):
-            results.extend(return_dict[rank])
-
-        for p in processes:
-            if p.is_alive():
-                p.terminate()
-            p.join()
-
-        return results
-
-    def get_average_score(self, pairs: list[QAPair]) -> float:
-        """
-        Get the average score of a batch of texts.
-        """
-        results = self.evaluate(pairs)
-        self.results = results
-        return sum(self.results) / len(pairs)
+        try:
+            self.tokenizer = AutoTokenizer.from_pretrained(reward_name)
+            self.model = AutoModelForSequenceClassification.from_pretrained(reward_name)
+            self.model.to(self.device)
+            self.model.eval()
+        except Exception as e:
+            raise RuntimeError(f"Failed to load reward model '{reward_name}': {e}") from e
 
-    def get_min_max_score(self, pairs: list[QAPair]) -> tuple[float, float]:
+    def evaluate(self, pair: QAPair) -> float:
         """
-        Get the min and max score of a batch of texts.
+        Evaluate a single question-answer pair using the reward model.
+
+        Args:
+            pair: QAPair containing question and answer strings
+
+        Returns:
+            Score as a float
         """
-        if self.results is None:
-            self.get_average_score(pairs)
-        return min(self.results), max(self.results)
+        # Tokenize
+        inputs = self.tokenizer(
+            pair.question,
+            pair.answer,
+            return_tensors="pt",
+            max_length=self.max_length,
+            truncation=True,
+        )
+        inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+        # Get score
+        score = self.model(**inputs).logits[0].item()
+
+        return score
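After the refactor, the evaluator loads the model once in __init__ and scores a single pair per evaluate call, instead of splitting a batch across one process per GPU; the batch helpers get_average_score and get_min_max_score are dropped. A minimal usage sketch (the module path and the QAPair constructor keywords are assumptions; the diff only shows that pair.question and pair.answer are read):

from graphgen.bases import QAPair
# Assumed module path for the refactored class:
from graphgen.models.evaluator.qa.reward_evaluator import RewardEvaluator

evaluator = RewardEvaluator(device="cpu")  # pass None (default) to auto-detect CUDA
pair = QAPair(
    question="What does the reward model score?",
    answer="How well an answer responds to its question.",
)
score = evaluator.evaluate(pair)  # for this reward model, higher means better
print(f"reward score: {score:.4f}")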

graphgen/models/evaluator/qa/uni_evaluator.py

Lines changed: 6 additions & 3 deletions
@@ -1,10 +1,10 @@
 # https://github.com/maszhongming/UniEval/tree/main

-from dataclasses import dataclass, field
+from dataclasses import field

 from tqdm import tqdm

-from graphgen.bases.datatypes import QAPair
+from graphgen.bases import BaseEvaluator, QAPair


 def _add_questions(dimension: str, question: str, answer: str):
@@ -32,8 +32,11 @@ def _add_questions(dimension: str, question: str, answer: str):
     return cur_input


-@dataclass
+
 class UniEvaluator:
+    """
+    UniEvaluator class
+    """
     model_name: str = "MingZhong/unieval-sum"
     dimensions: list = field(
         default_factory=lambda: ["naturalness", "coherence", "understandability"]
