Skip to content

Commit b188f8c

Browse files
authored
Add visualization script (#143)
1 parent 78cd8cf commit b188f8c

File tree

4 files changed

+116
-0
lines changed

4 files changed

+116
-0
lines changed

README.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ This tool automatically evaluates Japanese multi-modal large language models acr
2121
- [Supported Tasks](#supported-tasks)
2222
- [Required Libraries for Each VLM Model Inference](#required-libraries-for-each-vlm-model-inference)
2323
- [Benchmark-Specific Required Libraries](#benchmark-specific-required-libraries)
24+
- [Analyze VLMs Prediction](#analyze-vlms-prediction)
2425
- [License](#license)
2526
- [Contribution](#contribution)
2627
- [How to Add a Benchmark Task](#how-to-add-a-benchmark-task)
@@ -139,6 +140,16 @@ JIC-VQA only provide the image URL, so you need to download the images from the
139140
python scripts/prepare_jic_vqa.py
140141
```
141142

143+
## Analyze VLMs Prediction
144+
145+
Let's analyze VLM predictions!
146+
```bash
147+
uv run streamlit run scripts/browse_prediction.py --task_id "japanese-heron-bench" --result_dir "result"
148+
```
149+
You will see a visualization like the one below.
150+
![Streamlit](./assets/streamlit_visualization.png)
151+
152+
142153
## License
143154

144155
This repository is licensed under the Apache-2.0 License.

assets/streamlit_visualization.png

885 KB
Loading

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ dev = [
5959
"mypy>=1.15.0",
6060
"pytest>=8.3.4",
6161
"seaborn>=0.13.2",
62+
"streamlit>=1.43.2",
6263
]
6364

6465
evovlm = [

scripts/browse_prediction.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
import streamlit as st
2+
from datasets import load_dataset
3+
import random
4+
import eval_mm
5+
from argparse import ArgumentParser
6+
import os
7+
import json
8+
9+
10+
def parse_args(argv=None):
    """Parse command-line options for the prediction browser.

    Args:
        argv: Optional list of argument strings. Defaults to ``None``, in
            which case ``sys.argv[1:]`` is used (original behavior); passing
            an explicit list makes the function testable.

    Returns:
        argparse.Namespace with ``task_id`` (benchmark task to browse) and
        ``result_dir`` (root directory holding per-model predictions).
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--task_id",
        type=str,
        default="japanese-heron-bench",
        help="Task ID of the benchmark to browse (default: %(default)s).",
    )
    parser.add_argument(
        "--result_dir",
        type=str,
        default="result",
        help="Directory containing <task_id>/<model_id>/prediction.jsonl files.",
    )
    return parser.parse_args(argv)
16+
17+
18+
def scrollable_text(text):
    """Return *text* wrapped in an HTML div capped at 300px with a vertical scrollbar."""
    style = "max-height: 300px; overflow-y: auto; height: auto;"
    return f'<div style="{style}">{text}</div>'
22+
23+
24+
if __name__ == "__main__":
    args = parse_args()

    # Instantiate the requested benchmark task from the registry with the
    # default task configuration.
    task = eval_mm.tasks.TaskRegistry().get_task_cls(args.task_id)(
        eval_mm.tasks.TaskConfig()
    )

    # Models whose predictions are compared side by side. Each model must
    # have a prediction.jsonl under <result_dir>/<task_id>/<model_id>/.
    model_list = [
        "google/gemma-3-12b-it",
        "google/gemma-3-27b-it",
        "microsoft/Phi-4-multimodal-instruct",
    ]
    predictions_per_model = {}
    for model_id in model_list:
        prediction_path = os.path.join(
            args.result_dir, args.task_id, model_id, "prediction.jsonl"
        )
        with open(prediction_path, "r") as f:
            predictions_per_model[model_id] = [json.loads(line) for line in f]

    # Load the VQA dataset for the task.
    ds = task.dataset
    # Initialize session state.
    st.set_page_config(layout="wide")
    if "page" not in st.session_state:
        st.session_state.page = 0  # current page number (0-based)

    SAMPLES_PER_PAGE = 30  # number of samples shown per page
    # Last valid page index. Using (len(ds) - 1) // SAMPLES_PER_PAGE avoids
    # paging one past the final sample (an empty page) when len(ds) is an
    # exact multiple of SAMPLES_PER_PAGE.
    max_page = max(len(ds) - 1, 0) // SAMPLES_PER_PAGE
    # Columns: Question ID, Image, Question, Answer, then one per model.
    column_width_list = [1, 3, 3, 3] + [4] * len(model_list)
    st.write(f"# {args.task_id} dataset")

    def show_sample(idx):
        """Render one dataset sample and each model's prediction as a row."""
        sample = ds[idx]
        cols = st.columns(column_width_list)
        cols[0].markdown(task.doc_to_id(sample))
        cols[1].image(task.doc_to_visual(sample)[0], width=300)
        cols[2].markdown(
            scrollable_text(task.doc_to_text(sample)), unsafe_allow_html=True
        )
        cols[3].markdown(
            scrollable_text(task.doc_to_answer(sample)), unsafe_allow_html=True
        )
        # enumerate instead of model_list.index(): one O(1) offset per model
        # rather than an O(n) lookup inside the loop.
        for offset, model_id in enumerate(model_list):
            cols[4 + offset].markdown(
                scrollable_text(predictions_per_model[model_id][idx]["text"]),
                unsafe_allow_html=True,
            )

    # Navigation buttons
    nav_col1, nav_col2, nav_col3 = st.columns(3)
    if nav_col1.button(f"Prev {SAMPLES_PER_PAGE}"):
        st.session_state.page = max(st.session_state.page - 1, 0)
    if nav_col2.button("Random"):
        st.session_state.page = random.randint(0, max_page)
    if nav_col3.button(f"Next {SAMPLES_PER_PAGE}"):
        st.session_state.page = min(st.session_state.page + 1, max_page)

    # Compute the slice of samples on the current page.
    start_idx = st.session_state.page * SAMPLES_PER_PAGE
    end_idx = min(start_idx + SAMPLES_PER_PAGE, len(ds))

    st.write(f"### Showing samples {start_idx + 1} to {end_idx} / {len(ds)}")

    # Header row.
    header_cols = st.columns(column_width_list)
    header_cols[0].markdown("ID")
    header_cols[1].markdown("Image")
    header_cols[2].markdown("Question")
    header_cols[3].markdown("Answer")
    for offset, model_id in enumerate(model_list):
        header_cols[4 + offset].markdown(f"Prediction ({model_id})")

    # Sample rows.
    for idx in range(start_idx, end_idx):
        with st.container():
            show_sample(idx)
            st.markdown("---")

0 commit comments

Comments
 (0)