
Commit f3a0391

refactor: refactor UniEvaluator
1 parent 58ede2e commit f3a0391

File tree: 2 files changed (+91, -170 lines)


graphgen/models/evaluator/qa/reward_evaluator.py

Lines changed: 3 additions & 1 deletion
@@ -26,6 +26,7 @@ def __init__(

         import torch
         from transformers import AutoModelForSequenceClassification, AutoTokenizer
+        self.torch = torch

         # Set device (auto-detect if not specified)
         self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
@@ -59,6 +60,7 @@ def evaluate(self, pair: QAPair) -> float:
         inputs = {k: v.to(self.device) for k, v in inputs.items()}

         # Get score
-        score = self.model(**inputs).logits[0].item()
+        with self.torch.no_grad():
+            score = self.model(**inputs).logits[0].item()

         return score
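Note on this hunk: the forward pass is now wrapped in torch.no_grad(), so no autograd graph is built while scoring. A minimal sketch of the same inference pattern outside the class (the reward-model checkpoint name below is a placeholder assumption, not taken from this commit):

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_name = "OpenAssistant/reward-model-deberta-v3-large-v2"  # placeholder checkpoint, not from this commit
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name).eval()

inputs = tokenizer("What is a knowledge graph?", "A graph of entities and their relations.", return_tensors="pt")
with torch.no_grad():  # inference only: skip gradient tracking to save memory and time
    score = model(**inputs).logits[0].item()
print(score)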
Lines changed: 88 additions & 169 deletions
@@ -1,186 +1,105 @@
 # https://github.com/maszhongming/UniEval/tree/main
-
-from dataclasses import field
-
-from tqdm import tqdm
-
+from typing import Optional, List
 from graphgen.bases import BaseEvaluator, QAPair


-def _add_questions(dimension: str, question: str, answer: str):
-    if dimension == "naturalness":
-        cur_input = (
-            "question: Is this a natural response in the dialogue? </s> response: "
-            + answer
-        )
-    elif dimension == "coherence":
-        cur_input = (
-            "question: Is this a coherent response given the dialogue history? </s> response: "
-            + answer
-            + " </s> dialogue history: "
-            + question
-        )
-    elif dimension == "understandability":
-        cur_input = (
-            "question: Is this an understandable response in the dialogue? </s> response: "
-            + answer
-        )
-    else:
-        raise NotImplementedError(
-            "The input format for this dimension is still undefined. Please customize it first."
-        )
-    return cur_input
-
-
-
-class UniEvaluator:
+class UniEvaluator(BaseEvaluator):
     """
-    UniEvaluator class
+    UniEvaluator for single QAPair evaluation across quality dimensions.
+
+    Dimensions: naturalness, coherence, understandability
+
+    Usage:
+        evaluator = UniEvaluator()
+        pair = QAPair(question="...", answer="...")
+        scores = evaluator.evaluate(pair)
+        # {"naturalness": 0.85, "coherence": 0.92, "understandability": 0.88}
     """
-    model_name: str = "MingZhong/unieval-sum"
-    dimensions: list = field(
-        default_factory=lambda: ["naturalness", "coherence", "understandability"]
-    )
-    max_length: int = 2560
-    results: dict = None
-
-    def __post_init__(self):
-        import torch

-        self.num_gpus = torch.cuda.device_count()
-        self.results = {}
+    DEFAULT_MODEL: str = "MingZhong/unieval-sum"
+    DEFAULT_DIMS: List[str] = ["naturalness", "coherence", "understandability"]
+    DEFAULT_MAX_LENGTH: int = 2560

-    @staticmethod
-    def process_chunk(rank, pairs, model_name, max_length, dimension, return_dict):
+    def __init__(
+        self,
+        model_name: Optional[str] = None,
+        max_length: Optional[int] = None,
+        device: Optional[str] = None,
+    ):
+        """
+        Args:
+            model_name: HuggingFace model name/path
+            max_length: Tokenizer max sequence length
+            device: 'cuda', 'cpu', or None for auto-detect
+        """
         import torch
         from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+        self.torch = torch

-        device = f"cuda:{rank}"
-        torch.cuda.set_device(rank)
-
-        rank_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
-        tokenizer = AutoTokenizer.from_pretrained(model_name)
-        rank_model.to(device)
-        rank_model.eval()
-
-        softmax = torch.nn.Softmax(dim=1)
-
-        pos_id = tokenizer("Yes")["input_ids"][0]
-        neg_id = tokenizer("No")["input_ids"][0]
-
-        results = []
-        with torch.no_grad():
-            for pair in tqdm(pairs):
-                text = _add_questions(dimension, pair.question, pair.answer)
+        self.model_name = model_name or self.DEFAULT_MODEL
+        self.max_length = max_length or self.DEFAULT_MAX_LENGTH
+        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")

-                tgt = "No"
+        # Load model & tokenizer
+        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+        self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)
+        self.model.to(self.device)
+        self.model.eval()

-                encoded_src = tokenizer(
-                    text,
-                    max_length=max_length,
-                    truncation=True,
-                    padding=True,
-                    return_tensors="pt",
-                )
-                encoded_tgt = tokenizer(
-                    tgt,
-                    max_length=max_length,
-                    truncation=True,
-                    padding=True,
-                    return_tensors="pt",
-                )
+        # Pre-compute Yes/No token IDs
+        self._yes_id = self.tokenizer("Yes")["input_ids"][0]
+        self._no_id = self.tokenizer("No")["input_ids"][0]

-                src_tokens = encoded_src["input_ids"].to(device)
-                src_mask = encoded_src["attention_mask"].to(device)
-
-                tgt_tokens = encoded_tgt["input_ids"].to(device)[:, 0].unsqueeze(-1)
-
-                output = rank_model(
+    @staticmethod
+    def _build_input_text(dimension: str, question: str, answer: str) -> str:
+        """Construct input text for specified dimension."""
+        if dimension == "naturalness":
+            return f"question: Is this a natural response? </s> response: {answer}"
+        elif dimension == "coherence":
+            return f"question: Is this a coherent response? </s> response: {answer} </s> history: {question}"
+        elif dimension == "understandability":
+            return f"question: Is this an understandable response? </s> response: {answer}"
+        raise NotImplementedError(f"Unsupported dimension '{dimension}'")
+
+    def evaluate(
+        self,
+        pair: QAPair,
+        dimensions: Optional[List[str]] = None,
+    ) -> dict[str, float]:
+        """Evaluate a single QAPair across specified dimensions."""
+        dimensions = dimensions or self.DEFAULT_DIMS
+
+        # Validate dimensions
+        invalid = set(dimensions) - set(self.DEFAULT_DIMS)
+        if invalid:
+            raise ValueError(f"Invalid dimensions: {invalid}. Available: {self.DEFAULT_DIMS}")
+
+        results = {}
+        no_token = self.torch.tensor([[self._no_id]], device=self.device)
+
+        for dim in dimensions:
+            # Tokenize input
+            src = self.tokenizer(
+                self._build_input_text(dim, pair.question, pair.answer),
+                max_length=self.max_length,
+                truncation=True,
+                return_tensors="pt",
+            )
+            src_tokens = src["input_ids"].to(self.device)
+            src_mask = src["attention_mask"].to(self.device)
+
+            # Score
+            with self.torch.no_grad():
+                logits = self.model(
                     input_ids=src_tokens,
                     attention_mask=src_mask,
-                    labels=tgt_tokens,
+                    labels=no_token,
                     use_cache=False,
-                )
-
-                logits = output.logits.view(-1, rank_model.config.vocab_size)
-
-                pos_score = softmax(logits)[:, pos_id]  # Yes
-                neg_score = softmax(logits)[:, neg_id]
-                score = pos_score / (pos_score + neg_score)
-
-                results.append(score.item())
-
-        return_dict[rank] = results
-
-    def evaluate(self, pairs: list[QAPair]) -> list[dict]:
-        import torch.multiprocessing as mp
-
-        final_results = []
-        for dimension in self.dimensions:
-            chunk_size = len(pairs) // self.num_gpus
-            chunks = []
-            for i in range(self.num_gpus):
-                start = i * chunk_size
-                end = start + chunk_size
-                if i == self.num_gpus - 1:
-                    end = len(pairs)
-                chunks.append(pairs[start:end])
-
-            # multi-process
-            manager = mp.Manager()
-            return_dict = manager.dict()
-            processes = []
-
-            for rank, chunk in enumerate(chunks):
-                p = mp.Process(
-                    target=self.process_chunk,
-                    args=(
-                        rank,
-                        chunk,
-                        self.model_name,
-                        self.max_length,
-                        dimension,
-                        return_dict,
-                    ),
-                )
-                p.start()
-                processes.append(p)
-
-            for p in processes:
-                p.join()
-
-            # merge results
-            results = []
-            for rank in range(len(chunks)):
-                results.extend(return_dict[rank])
-
-            for p in processes:
-                if p.is_alive():
-                    p.terminate()
-                    p.join()
-
-            final_results.append({dimension: results})
-        return final_results
-
-    def get_average_score(self, pairs: list[QAPair]) -> dict:
-        """
-        Get the average score of a batch of texts.
-        """
-        results = self.evaluate(pairs)
-        final_results = {}
-        for result in results:
-            for key, value in result.items():
-                final_results[key] = sum(value) / len(value)
-                self.results[key] = value
-        return final_results
-
-    def get_min_max_score(self, pairs: list[QAPair]) -> dict:
-        """
-        Get the min and max score of a batch of texts.
-        """
-        if self.results is None:
-            self.get_average_score(pairs)
-        final_results = {}
-        for key, value in self.results.items():
-            final_results[key] = min(value), max(value)
-        return final_results
+                ).logits[:, 0, :]  # [1, vocab_size]
+
+                probs = self.torch.softmax(logits, dim=-1)[0]
+                score = probs[self._yes_id] / (probs[self._yes_id] + probs[self._no_id])
+
+            results[dim] = score.item()
+
+        return results
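For reference, a rough usage sketch of the refactored UniEvaluator, following the Usage block in the new docstring (the module import path is assumed from the commit context and may differ; QAPair field names are taken from the diff):

from graphgen.bases import QAPair
from graphgen.models.evaluator.qa.uni_evaluator import UniEvaluator  # assumed path, not shown in this commit

evaluator = UniEvaluator(device="cpu")  # device=None would auto-detect CUDA
pair = QAPair(
    question="What is a knowledge graph?",
    answer="A graph of entities and the relations between them.",
)

# Each dimension score is P("Yes") / (P("Yes") + P("No")) over the first decoded token
scores = evaluator.evaluate(pair, dimensions=["naturalness", "coherence"])
print(scores)  # e.g. {"naturalness": 0.87, "coherence": 0.91}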
