Skip to content

Commit bd80f1f

Browse files
authored
Merge pull request #78 from llm-jp/74-custom_metrics
Add custom metrics features and Refactoring
2 parents e3230ac + 56a19b3 commit bd80f1f

20 files changed

+1477
-1347
lines changed

README.md

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,14 @@
1818
- [LLM-jp-eval-mm](#llm-jp-eval-mm)
1919
- [目次](#目次)
2020
- [環境構築](#環境構築)
21+
- [PyPIでインストールする](#pypiでインストールする)
22+
- [GitHubをCloneする場合](#githubをcloneする場合)
2123
- [評価方法](#評価方法)
22-
- [サンプルコードの実行](#サンプルコードの実行)
24+
- [評価の実行](#評価の実行)
2325
- [評価結果の確認](#評価結果の確認)
2426
- [リーダーボードの公開](#リーダーボードの公開)
2527
- [サポートするタスク](#サポートするタスク)
28+
- [各VLMモデル推論時の必要ライブラリ情報](#各vlmモデル推論時の必要ライブラリ情報)
2629
- [ライセンス](#ライセンス)
2730
- [Contribution](#contribution)
2831

@@ -98,10 +101,13 @@ rye run bash examples/evaluate.sh
98101
その場合は以下のコマンドを実行してください.
99102

100103
```bash
101-
rye run python3 examples/sample.py \
102-
--class_path llava_1_5 \
103-
--task_id japanese-heron-bench \
104-
--openai_model_id gpt-4o-mini-2024-07-18
104+
python3 examples/sample.py \
105+
--class_path llava_1_5_7b_hf \
106+
--task_id japanese-heron-bench \
107+
--result_dir test \
108+
--metrics "llm_as_a_judge_heron_bench,exact_match,rougel" \
109+
--judge_model "gpt-4o-2024-05-13" \
110+
--overwrite
105111
```
106112

107113
### 評価結果の確認
@@ -135,6 +141,30 @@ rye run python3 scripts/japanese-heron-bench/record_output.py
135141
- JA-Multi-Image-VQA
136142
- JMMMU
137143

144+
## 各VLMモデル推論時の必要ライブラリ情報
145+
146+
- OpenGVLab/InternVL2-8B
147+
148+
OOM防止のためFlashAttentionのInstallが必要です.
149+
```bash
150+
uv pip install flash-attn --no-build-isolation --python .venv
151+
```
152+
153+
- Llama_3_EvoVLM_JP_v2
154+
155+
mantis-vl のインストールが必要です.
156+
```bash
157+
rye add "datasets==2.18.0"
158+
rye add --dev mantis-vl --git=https://github.com/TIGER-AI-Lab/Mantis.git
159+
```
160+
161+
- Qwen/Qwen2-VL-7B-Instruct
162+
163+
qwen-vl-utils のインストールが必要です.
164+
```bash
165+
rye add --dev qwen-vl-utils
166+
```
167+
138168
## ライセンス
139169

140170
各評価データセットのライセンスは[DATASET.md](./DATASET.md)を参照してください.

examples/InternVL2_8B.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,13 +166,16 @@ def generate(
166166
pixel_values = (
167167
load_image(image, max_num=12).to(self.model.device).to(self.model.dtype)
168168
)
169+
import copy
170+
generation_config = copy.deepcopy(gen_kwargs.__dict__)
171+
generation_config.pop("use_cache")
169172

170173
response = self.model.chat(
171174
self.tokenizer,
172175
pixel_values,
173176
text,
174177
num_patches_list=num_patches_list,
175-
generation_config=gen_kwargs.__dict__,
178+
generation_config=generation_config,
176179
)
177180
generated_text = response
178181
return generated_text

examples/sample.py

Lines changed: 50 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
parser = argparse.ArgumentParser()
1212
parser.add_argument("--class_path", type=str, default="llava_1_5_7b_hf")
1313
parser.add_argument("--task_id", type=str, default="japanese-heron-bench")
14-
parser.add_argument("--openai_model_id", type=str, default="gpt-4o-mini-2024-07-18")
14+
parser.add_argument("--judge_model", type=str, default="gpt-4o-mini-2024-07-18")
1515
parser.add_argument("--batch_size_for_evaluation", type=int, default=10)
1616
parser.add_argument("--overwrite", action="store_true")
1717
parser.add_argument("--result_dir", type=str, default="result")
@@ -22,6 +22,18 @@
2222
parser.add_argument("--top_p", type=float, default=1.0)
2323
parser.add_argument("--do_sample", action="store_true", default=False)
2424
parser.add_argument("--use_cache", action="store_true", default=True)
25+
parser.add_argument(
26+
"--max_dataset_len",
27+
type=int,
28+
default=None,
29+
help="max data size for evaluation. If None, use all data. Else, use the first n data.",
30+
)
31+
parser.add_argument(
32+
"--metrics",
33+
type=str,
34+
default="llm_as_a_judge_heron_bench",
35+
help="metrics to evaluate. You can specify multiple metrics separated by comma (e.g. --metrics exact_match,rougel).",
36+
)
2537

2638
args = parser.parse_args()
2739

@@ -36,13 +48,16 @@
3648

3749
class_path = args.class_path
3850
task_id = args.task_id
39-
openai_model_id = args.openai_model_id
4051

4152
module = importlib.import_module(class_path)
4253
model_id = module.VLM.model_id.replace("/", "-")
4354

44-
task = eval_mm.api.registry.get_task(task_id)
45-
dataset = task.dataset
55+
task_config = eval_mm.api.task.TaskConfig(
56+
max_dataset_len=args.max_dataset_len,
57+
judge_model=args.judge_model,
58+
batch_size_for_evaluation=args.batch_size_for_evaluation,
59+
)
60+
task = eval_mm.api.registry.get_task_cls(task_id)(task_config)
4661

4762
# save the predictions to jsonl file
4863
os.makedirs(args.result_dir, exist_ok=True)
@@ -57,16 +72,19 @@
5772

5873
prediction_result_file_path = os.path.join(prediction_result_dir, f"{model_id}.jsonl")
5974

60-
6175
# if prediction is already done, load the prediction
6276
if os.path.exists(prediction_result_file_path) and not args.overwrite:
6377
with open(prediction_result_file_path, "r") as f:
6478
preds = [json.loads(line) for line in f]
79+
assert (
80+
len(preds) == len(task.dataset)
81+
), f"Prediction result length is not equal to the dataset length. Prediction result length: {len(preds)}, Dataset length: {len(task.dataset)}"
6582
print(f"Prediction result loaded from {prediction_result_file_path}")
6683
else:
6784
model = module.VLM()
6885
preds = []
69-
for doc in tqdm(dataset):
86+
print(task.dataset)
87+
for doc in tqdm(task.dataset):
7088
# print("doc", doc)
7189
image = task.doc_to_visual(doc)
7290
text = task.doc_to_text(doc)
@@ -90,20 +108,36 @@
90108
exit()
91109
print("Evaluation start")
92110
# evaluate the predictions
93-
metrics, eval_results = task.compute_metrics(
94-
preds, model_id=openai_model_id, batch_size=args.batch_size_for_evaluation
95-
)
111+
112+
metrics = args.metrics.split(",")
113+
114+
scores_for_each_metric = {}
115+
116+
for metric in metrics:
117+
scores_for_each_metric[metric] = task.calc_scores(preds, metric)
118+
print(f"Scores for {metric}: {scores_for_each_metric[metric]}")
119+
120+
calculated_metrics = {}
121+
122+
for metric in metrics:
123+
calculated_metrics[metric] = task.gather_scores(
124+
scores_for_each_metric[metric], metric
125+
)
126+
print(f"{metric}: {calculated_metrics[metric]}")
96127

97128

98-
results = task.format_result(preds, eval_results)
99129
with open(os.path.join(prediction_result_file_path), "w") as f:
100-
for result in results:
101-
f.write(json.dumps(result, ensure_ascii=False) + "\n")
130+
for i, pred in enumerate(preds):
131+
question_id = pred["question_id"]
132+
text = pred["text"]
133+
answer = task.doc_to_answer(task.dataset[i])
134+
content = {"question_id": question_id, "text": text, "answer": answer}
135+
for metric in metrics:
136+
content[metric] = scores_for_each_metric[metric][i]
137+
f.write(json.dumps(content, ensure_ascii=False) + "\n")
102138
print(f"Prediction result saved to {prediction_result_file_path}")
103139

104140
eval_result_file_path = os.path.join(evaluation_result_dir, f"{model_id}.jsonl")
105141
with open(eval_result_file_path, "w") as f:
106-
f.write(json.dumps(metrics, ensure_ascii=False) + "\n")
107-
108-
print(f"Metrics: {metrics}")
109-
print(f"Evaluation result example: {eval_results[0]}")
142+
f.write(json.dumps(calculated_metrics, ensure_ascii=False) + "\n")
143+
print(f"Evaluation result saved to {eval_result_file_path}")

src/eval_mm/api/registry.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,9 @@ def decorate(fn):
1818
return decorate
1919

2020

21-
def get_task(task_name):
21+
def get_task_cls(task_name):
2222
try:
2323
task_cls = TASK_REGISTRY[task_name]
24-
task = task_cls()
25-
return task
24+
return task_cls
2625
except KeyError:
2726
raise KeyError(f"Missing task {task_name}")

src/eval_mm/api/task.py

Lines changed: 28 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,15 @@
11
import abc
2-
from collections.abc import Callable
3-
from dataclasses import asdict, dataclass
42

3+
from dataclasses import dataclass
4+
from eval_mm.utils.azure_client import OpenAIChatAPI
5+
from datasets import Dataset
56

6-
@dataclass
7-
class TaskConfig(dict):
8-
def __getitem__(self, item):
9-
return getattr(self, item)
10-
11-
def __setitem__(self, item, value):
12-
return setattr(self, item, value)
137

14-
def to_dict(self):
15-
"""dumps the current config as a dictionary object, as a printable format.
16-
:return: dict
17-
A printable dictionary version of the TaskConfig object.
18-
"""
19-
cfg_dict = asdict(self)
20-
# remove values that are `None`
21-
for k, v in list(cfg_dict.items()):
22-
if v is None:
23-
cfg_dict.pop(k)
24-
elif isinstance(v, Callable):
25-
# TODO: this should handle Promptsource template objects as a separate case?
26-
cfg_dict[k] = str(v)
27-
return cfg_dict
8+
@dataclass
9+
class TaskConfig:
10+
max_dataset_len: int | None = None
11+
judge_model: str = "gpt-4o-mini-2024-07-18"
12+
batch_size_for_evaluation: int = 10
2813

2914

3015
class Task(abc.ABC):
@@ -34,24 +19,21 @@ class Task(abc.ABC):
3419
{"question": ..., "answer": ...} or {"question": ..., question, answer)
3520
"""
3621

37-
def __init__(self, config=None) -> None:
38-
self._config = TaskConfig({**config}) if config else TaskConfig()
22+
def __init__(self, config: TaskConfig):
3923
self._dataset = None
40-
self.prepare_task(config)
41-
42-
@property
43-
def config(self):
44-
"""Returns the TaskConfig associated with this class."""
45-
return self._config
24+
self.client = OpenAIChatAPI()
25+
self.config = config
4626

47-
@property
48-
def dataset(self):
49-
"""Returns the dataset associated with this class."""
50-
return self._dataset
27+
if self.config.max_dataset_len is not None:
28+
self.dataset = self._prepare_dataset().select(
29+
range(self.config.max_dataset_len)
30+
)
31+
else:
32+
self.dataset = self._prepare_dataset()
5133

5234
@abc.abstractmethod
53-
def prepare_task(self, config):
54-
"""Prepares a document for evaluation."""
35+
def _prepare_dataset(self) -> Dataset:
36+
"""Prepares the dataset."""
5537
pass
5638

5739
@abc.abstractmethod
@@ -70,18 +52,16 @@ def doc_to_id(self, doc):
7052
pass
7153

7254
@abc.abstractmethod
73-
def evaluate(self, docs: list, preds: list) -> list[dict]:
74-
"""Evaluate batch prediction."""
55+
def doc_to_answer(self, doc):
56+
"""Converts a document to answer."""
57+
pass
58+
59+
@abc.abstractmethod
60+
def calc_scores(self, preds: list, metric: str) -> list:
61+
"""Calculates scores for the predictions."""
7562
pass
7663

7764
@abc.abstractmethod
78-
def compute_metrics(self, preds):
79-
"""
80-
Args:
81-
doc: a instance of the eval dataset
82-
results: [pred]
83-
Returns:
84-
metrics: a dictionary with key: metric name (in this case coco_bleu), value: metric value
85-
results_verbose: a dictionary with key: metric name, value: a dictionary with key: 'score' and 'verbose'
86-
"""
65+
def gather_scores(self, scores: list[dict], metric: str) -> dict:
66+
"""Aggregates the scores."""
8767
pass

src/eval_mm/metrics/__init__.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from .heron_bench_scorer import HeronBenchScorer
2+
from .exact_match_scorer import ExactMatchScorer
3+
from .llm_as_a_judge_scorer import LlmAsaJudgeScorer
4+
from .rougel_scorer import RougeLScorer
5+
from .substring_match_scorer import SubstringMatchScorer
6+
from .scorer import Scorer
7+
from .jmmmu_scorer import JMMMUScorer
8+
9+
10+
class ScorerRegistry:
11+
"""Registry to map metrics to their corresponding scorer classes."""
12+
13+
_scorers = {
14+
"llm_as_a_judge_heron_bench": HeronBenchScorer,
15+
"exact_match": ExactMatchScorer,
16+
"llm_as_a_judge": LlmAsaJudgeScorer,
17+
"rougel": RougeLScorer,
18+
"substring_match": SubstringMatchScorer,
19+
"jmmmu": JMMMUScorer,
20+
}
21+
22+
@classmethod
23+
def register(cls, metric: str, scorer_class: type):
24+
"""Register a new scorer for a metric."""
25+
cls._scorers[metric] = scorer_class
26+
27+
@classmethod
28+
def get_scorer(cls, metric: str) -> Scorer:
29+
"""Get the scorer class for the given metric."""
30+
try:
31+
return cls._scorers[metric]
32+
except KeyError:
33+
raise ValueError(f"Metric '{metric}' is not supported.")
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from .scorer import Scorer
2+
3+
4+
class ExactMatchScorer(Scorer):
5+
@staticmethod
6+
def score(refs: list[str], preds: list[str], **kwargs) -> list[int]:
7+
scores = [int(ref == pred) for ref, pred in zip(refs, preds)]
8+
return scores
9+
10+
@staticmethod
11+
def aggregate(scores: list[int], **kwargs) -> float:
12+
return sum(scores) / len(scores)

0 commit comments

Comments (0)