Skip to content

Commit f593fec

Browse files
authored
Add JDocQA task (#79)
1 parent 48fe6c0 commit f593fec

File tree

5 files changed

+224
-0
lines changed

5 files changed

+224
-0
lines changed

README.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
- [リーダーボードの公開](#リーダーボードの公開)
2727
- [サポートするタスク](#サポートするタスク)
2828
- [各VLMモデル推論時の必要ライブラリ情報](#各vlmモデル推論時の必要ライブラリ情報)
29+
- [タスク固有の必要ライブラリ情報](#タスク固有の必要ライブラリ情報)
2930
- [ライセンス](#ライセンス)
3031
- [Contribution](#contribution)
3132

@@ -165,6 +166,16 @@ qwen-vl-utils のインストールが必要です.
165166
rye add --dev qwen-vl-utils
166167
```
167168

169+
## タスク固有の必要ライブラリ情報
170+
171+
- JDocQA
172+
173+
```bash
174+
sudo apt-get install poppler-utils
175+
rye add pdf2image
176+
rye add "sacrebleu[ja]"
177+
```
178+
168179
## ライセンス
169180

170181
各評価データセットのライセンスは[DATASET.md](./DATASET.md)を参照してください.

src/eval_mm/metrics/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from .substring_match_scorer import SubstringMatchScorer
66
from .scorer import Scorer
77
from .jmmmu_scorer import JMMMUScorer
8+
from .jdocqa_scorer import JDocQAScorer
89

910

1011
class ScorerRegistry:
@@ -17,6 +18,7 @@ class ScorerRegistry:
1718
"rougel": RougeLScorer,
1819
"substring_match": SubstringMatchScorer,
1920
"jmmmu": JMMMUScorer,
21+
"jdocqa": JDocQAScorer,
2022
}
2123

2224
@classmethod
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
from eval_mm.metrics.scorer import Scorer
2+
from sacrebleu import sentence_bleu
3+
from unicodedata import normalize
4+
5+
# Integer codes used by the JDocQA dataset's ``answer_type`` column
# (assumed to match shunk031/JDocQA — confirm against the dataset card).
ANSWER_TYPE_MAP = dict(
    yesno=0,  # Yes/No questions
    factoid=1,  # Factoid questions
    numerical=2,  # Numerical questions
)
ANSWER_TYPE_MAP["open-ended"] = 3  # Open-ended questions (hyphen rules out kwarg form)

# Reverse lookup: numeric code -> human-readable answer-type name.
NUM_TO_ANSWER_TYPE = {code: name for name, code in ANSWER_TYPE_MAP.items()}
14+
15+
def jdocqa_normalize(text):
    """Normalize a JDocQA answer string for substring-match scoring.

    Removes the polite suffix "です", Japanese punctuation ("。", "、") and
    ASCII spaces, strips surrounding whitespace, then applies NFKC
    normalization (e.g. full-width digits become half-width).
    """
    for token in ("です", "。", "、", " "):
        text = text.replace(token, "")
    return normalize("NFKC", text.strip())
25+
26+
27+
def bleu_ja(refs, pred):
    """Sentence-level BLEU for Japanese text, tokenized with MeCab.

    Args:
        refs: List of reference strings.
        pred: Hypothesis (prediction) string.

    Returns:
        The BLEU score as a float in [0, 100].
    """
    result = sentence_bleu(
        references=refs,
        hypothesis=pred,
        tokenize="ja-mecab",  # MeCab tokenization for Japanese
        smooth_method="exp",
        smooth_value=0.0,
        use_effective_order=False,
        lowercase=False,
    )
    return result.score
38+
39+
40+
class JDocQAScorer(Scorer):
    """Scorer for the JDocQA task.

    Yes/No, factoid and numerical questions are scored by normalized
    substring match (1/0); open-ended questions are scored with Japanese
    sentence-level BLEU.
    """

    @staticmethod
    def score(refs: list[str], preds: list[str], **kwargs) -> list[float]:
        """Score each prediction against its reference.

        Args:
            refs: Reference answer strings.
            preds: Predicted answer strings.
            **kwargs: Must contain "docs", the dataset rows; each row's
                "answer_type" integer selects the scoring rule.

        Returns:
            One score per example: a BLEU score (float in [0, 100]) for
            open-ended questions, otherwise 1 or 0 for substring match.
            (Fixed annotation: the BLEU branch yields floats, so the
            element type is float, not int.)

        Raises:
            NotImplementedError: If a row carries an unknown answer type.
        """
        docs = kwargs["docs"]
        scores: list[float] = []

        for doc, ref, pred in zip(docs, refs, preds):
            if doc["answer_type"] == ANSWER_TYPE_MAP["open-ended"]:
                scores.append(bleu_ja([ref], pred))
            elif doc["answer_type"] in [
                ANSWER_TYPE_MAP["yesno"],
                ANSWER_TYPE_MAP["factoid"],
                ANSWER_TYPE_MAP["numerical"],
            ]:
                # Normalized substring containment counts as correct.
                ref = jdocqa_normalize(ref)
                pred = jdocqa_normalize(pred)
                scores.append(1 if ref in pred else 0)
            else:
                raise NotImplementedError("Bad answer type.")

        return scores

    @staticmethod
    def aggregate(scores: list[float], **kwargs) -> dict:
        """Aggregate per-example scores into per-answer-type means.

        Args:
            scores: Per-example scores produced by :meth:`score`.
            **kwargs: Must contain "docs", aligned with ``scores``.

        Returns:
            Dict with keys "yesno_exact", "factoid_exact",
            "numerical_exact" and "open-ended_bleu"; each value is the
            mean score for that answer type, or 0 when no example of
            that type is present.
        """
        docs = kwargs["docs"]
        metrics = {
            "yesno_exact": [],
            "factoid_exact": [],
            "numerical_exact": [],
            "open-ended_bleu": [],
        }
        for doc, score in zip(docs, scores):
            answer_type = doc["answer_type"]
            if answer_type == ANSWER_TYPE_MAP["open-ended"]:
                metrics["open-ended_bleu"].append(score)
            else:
                metrics[f"{NUM_TO_ANSWER_TYPE[answer_type]}_exact"].append(score)

        # Replace each bucket by its mean; empty buckets become 0.
        for key, value in metrics.items():
            metrics[key] = sum(value) / len(value) if value else 0

        return metrics
88+
89+
90+
if __name__ == "__main__":
    from datasets import load_dataset

    # Smoke-test the scorer on a handful of JDocQA test examples.
    ds = load_dataset("shunk031/JDocQA", split="test")
    ds = ds.select(range(10))

    # With ref == pred, BLEU should be at its maximum.
    sample = ds["answer"][0]
    print(sample)
    print(sample)
    print(bleu_ja([sample], sample))

    pairs = list(zip(ds["answer_type"], ds["answer"]))
    print("Original answers")
    for answer_type, answer in pairs:
        print(NUM_TO_ANSWER_TYPE[answer_type], answer)

    print("JDocQA normalized answers")
    for answer_type, answer in pairs:
        print(NUM_TO_ANSWER_TYPE[answer_type], jdocqa_normalize(answer))

    # Self-scoring: predictions equal to references.
    scores = JDocQAScorer.score(refs=ds["answer"], preds=ds["answer"], docs=ds)
    print(scores)
    print(JDocQAScorer.aggregate(scores, docs=ds))

src/eval_mm/tasks/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@
33
from .ja_vlm_bench_in_the_wild import JaVLMBenchIntheWild
44
from .jmmmu import JMMMU
55
from .ja_multi_image_vqa import JAMultiImageVQA
6+
from .jdocqa import JDocQA

src/eval_mm/tasks/jdocqa.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
from datasets import Dataset, load_dataset
2+
from pdf2image import convert_from_path
3+
4+
from ..api.registry import register_task
5+
from ..api.task import Task
6+
from eval_mm.metrics import ScorerRegistry
7+
8+
import aiohttp
9+
from PIL import Image
10+
11+
Image.MAX_IMAGE_PIXELS = None
12+
13+
14+
def pdf_to_images(pdf_path):
    """Render the PDF at ``pdf_path`` into PIL images, one per page.

    Returns:
        A list of PIL images in page order (pdf2image default settings).
    """
    return convert_from_path(pdf_path)
17+
18+
19+
def get_elements_from_index(indices_str, array):
    """Select elements of ``array`` by a 1-based, comma-separated index string.

    Used to pick the PDF pages listed in JDocQA's ``question_page_number``
    field (e.g. ``"1,3"`` selects the first and third pages).

    Args:
        indices_str: Comma-separated 1-based indices, whitespace tolerated.
        array: Sequence to index into.

    Returns:
        The selected elements; out-of-range indices are silently skipped,
        so the result may be empty. Returns ``None`` when ``indices_str``
        is not a valid comma-separated list of integers.
    """
    try:
        indices = [int(part.strip()) - 1 for part in indices_str.split(",")]
    except ValueError:
        # Malformed page-number string (non-numeric token).
        print("The string doesn't seem to have numbers or commas in the right places.")
        return None
    # Bounds are guarded here, so no IndexError can occur; the original
    # unreachable `except IndexError` handler has been dropped.
    return [array[i] for i in indices if 0 <= i < len(array)]
30+
31+
32+
@register_task("jdocqa")
class JDocQA(Task):
    """JDocQA task: Japanese question answering over PDF documents.

    Visuals are rendered at evaluation time from each example's PDF file,
    using the page numbers referenced by the question.
    """

    @staticmethod
    def _prepare_dataset() -> Dataset:
        """Load the shunk031/JDocQA test split, trimmed to the columns
        this task needs, with a sequential ``question_id`` added.
        """
        ds = load_dataset(
            "shunk031/JDocQA",
            split="test",
            # NOTE(review): rename_pdf_category appears to be a kwarg of the
            # dataset's loading script (passed through by trust_remote_code)
            # — confirm against the shunk031/JDocQA dataset card.
            rename_pdf_category=True,
            trust_remote_code=True,
            # Long timeout: the dataset download includes large PDF archives.
            storage_options={
                "client_kwargs": {"timeout": aiohttp.ClientTimeout(total=3600)}
            },
        )
        ds = ds.rename_column("question", "input_text")
        # Assign a stable per-row id from the row index.
        ds = ds.map(lambda example, idx: {"question_id": idx}, with_indices=True)
        keep_columns = [
            "input_text",
            "pdf_filepath",
            "question_page_number",
            "question_id",
            "answer",
            "answer_type",
        ]
        ds = ds.remove_columns(
            [col for col in ds.column_names if col not in keep_columns]
        )
        return ds

    @staticmethod
    def doc_to_text(doc) -> str:
        """Return the question text for one example."""
        return doc["input_text"]

    @staticmethod
    def doc_to_visual(doc):
        """Render the pages referenced by the question as PIL images.

        Converts the whole PDF, then selects the pages named in
        ``question_page_number`` (1-based, comma-separated). May return
        ``None`` when that string is malformed — callers should be
        prepared for it (see get_elements_from_index).
        """
        images_all = pdf_to_images(doc["pdf_filepath"])
        images = get_elements_from_index(doc["question_page_number"], images_all)
        return images

    @staticmethod
    def doc_to_id(doc):
        """Return the sequential question id assigned in _prepare_dataset."""
        return doc["question_id"]

    @staticmethod
    def doc_to_answer(doc) -> str:
        """Return the reference answer string for one example."""
        return doc["answer"]

    def calc_scores(self, preds: list, metric: str) -> list:
        """Calculate scores of each prediction based on the metric."""
        docs = self.dataset
        refs = [doc["answer"] for doc in docs]
        pred_texts = [pred["text"] for pred in preds]
        scorer = ScorerRegistry.get_scorer(metric)
        # client/judge_model/batch_size are supplied for scorers that use an
        # LLM judge; JDocQAScorer.score only reads "docs" and ignores them.
        kwargs = {
            "docs": docs,
            "client": self.client,
            "judge_model": self.config.judge_model,
            "batch_size": self.config.batch_size_for_evaluation,
        }
        return scorer.score(refs, pred_texts, **kwargs)

    def gather_scores(self, scores: list[dict], metric: str) -> dict:
        """Aggregate per-example scores via the metric's scorer."""
        kwargs = {"docs": self.dataset}
        scorer = ScorerRegistry.get_scorer(metric)
        return scorer.aggregate(scores, **kwargs)

0 commit comments

Comments
 (0)