Skip to content

Commit ab92b69

Browse files
committed
feat: fix spark
1 parent d09737d commit ab92b69

File tree

1 file changed

+15
-11
lines changed

1 file changed

+15
-11
lines changed

dingo/exec/spark.py

Lines changed: 15 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -89,7 +89,8 @@ def _aggregate_eval_details(acc, item):
8989
if field_key not in acc['metric_scores']:
9090
acc['metric_scores'][field_key] = {}
9191

92-
# 遍历 List[EvalDetail]
92+
# 遍历 List[EvalDetail],同时收集指标分数和标签
93+
label_set = set()
9394
for eval_detail in eval_detail_list:
9495
# 收集指标分数(用于RAG等评估场景,按 field_key 分组)
9596
score = eval_detail.get('score') if isinstance(eval_detail, dict) else getattr(eval_detail, 'score', None)
@@ -100,15 +101,18 @@ def _aggregate_eval_details(acc, item):
100101
acc['metric_scores'][field_key][metric] = []
101102
acc['metric_scores'][field_key][metric].append(score)
102103

103-
# 收集标签统计
104+
# 收集标签统计(使用 set 去重,避免同一 item 中重复 label 多次计数)
104105
label_list = eval_detail.get('label', []) if isinstance(eval_detail, dict) else getattr(eval_detail, 'label', [])
105106
if label_list:
106-
# 统计每个 label 的出现次数
107107
for label in label_list:
108-
if label not in acc['label_counts'][field_key]:
109-
acc['label_counts'][field_key][label] = 1
110-
else:
111-
acc['label_counts'][field_key][label] += 1
108+
label_set.add(label)
109+
110+
# 对该 item 的每个唯一 label 计数 +1
111+
for label in label_set:
112+
if label not in acc['label_counts'][field_key]:
113+
acc['label_counts'][field_key][label] = 1
114+
else:
115+
acc['label_counts'][field_key][label] += 1
112116

113117
return acc
114118

@@ -197,14 +201,14 @@ def execute(self) -> SummaryModel:
197201

198202
def evaluate(self, data_rdd_item) -> Dict[str, Any]:
199203
"""Evaluate a single data item using broadcast variables."""
200-
data: Data = data_rdd_item
201-
result_info = ResultInfo(raw_data = data.to_dict())
204+
data: Data = data_rdd_item.asDict()
205+
result_info = ResultInfo(raw_data = data)
202206

203207
for e_p in self.input_args.evaluator:
204208
if e_p.fields:
205-
map_data = {k: data.to_dict().get(v) for k, v in e_p.fields.items()}
209+
map_data = {k: data.get(v) for k, v in e_p.fields.items()}
206210
else:
207-
map_data = data.to_dict()
211+
map_data = data
208212
eval_list_rule = [eval for eval in e_p.evals if eval.name in Model.rule_name_map]
209213
eval_list_llm = [eval for eval in e_p.evals if eval.name in Model.llm_name_map]
210214
for eval_type in ["rule", "llm"]:

0 commit comments

Comments
 (0)