Skip to content

Commit ab4ba8d

Browse files
committed
Add jdocqa overall score and validate metrics argument
1 parent 461c64e commit ab4ba8d

File tree

2 files changed

+23
-2
lines changed

2 files changed

+23
-2
lines changed

examples/sample.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,27 @@
3131
"--metrics",
3232
type=str,
3333
default="llm_as_a_judge_heron_bench",
34-
help="metrics to evaluate. You can specify multiple metrics separated by comma (e.g. --metrics exact_match,rougel).",
34+
help="metrics to evaluate. You can specify multiple metrics separated by comma (e.g. --metrics exact_match,rougel). Valid metrics: rougel,substring_match,jmmmu,jdocqa,llm_as_a_judge_heron_bench,exact_match",
3535
)
3636

37+
# Metric names accepted by the --metrics CLI argument.
valid_metrics = [
    "rougel",
    "substring_match",
    "jmmmu",
    "jdocqa",
    "llm_as_a_judge_heron_bench",
    "exact_match",
]


def validate_metrics(metrics: list[str]):
    """Validate the requested metric names against the supported set.

    Args:
        metrics: metric names parsed from the comma-separated --metrics flag.

    Raises:
        ValueError: on the first metric that is not in ``valid_metrics``.
    """
    unknown = next((m for m in metrics if m not in valid_metrics), None)
    if unknown is not None:
        raise ValueError(
            f"Invalid metric: {unknown}. Valid metrics are {valid_metrics}"
        )
53+
54+
3755
args = parser.parse_args()
3856

3957
gen_kwargs = GenerationConfig(
@@ -105,6 +123,7 @@
105123
# evaluate the predictions
106124

107125
metrics = args.metrics.split(",")
126+
validate_metrics(metrics)
108127

109128
scores_for_each_metric = {}
110129

src/eval_mm/metrics/jdocqa_scorer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def jdocqa_normalize(text):
2525

2626

2727
def bleu_ja(refs, pred):
28+
"""Calculate BLEU score for Japanese text. Score is normalized to [0, 1]."""
2829
bleu_score = sentence_bleu(
2930
hypothesis=pred,
3031
references=refs,
@@ -34,7 +35,7 @@ def bleu_ja(refs, pred):
3435
use_effective_order=False,
3536
lowercase=False,
3637
)
37-
return bleu_score.score
38+
return bleu_score.score / 100
3839

3940

4041
class JDocQAScorer(Scorer):
@@ -83,6 +84,7 @@ def aggregate(scores: list[int], **kwargs) -> dict:
8384
metrics[key] = 0
8485
continue
8586
metrics[key] = sum(value) / len(value)
87+
metrics["overall"] = sum(scores) / len(scores)
8688

8789
return metrics
8890

0 commit comments

Comments
 (0)