|
| 1 | +import streamlit as st |
| 2 | +from datasets import load_dataset |
| 3 | +import random |
| 4 | +import eval_mm |
| 5 | +from argparse import ArgumentParser |
| 6 | +import os |
| 7 | +import json |
| 8 | + |
| 9 | + |
def parse_args(argv=None):
    """Parse command-line options for the prediction viewer.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            falls back to ``sys.argv[1:]``, which is standard argparse
            behavior; passing a list makes the function testable.

    Returns:
        argparse.Namespace with ``task_id`` (benchmark task identifier)
        and ``result_dir`` (root directory holding prediction files).
    """
    parser = ArgumentParser()
    parser.add_argument("--task_id", type=str, default="japanese-heron-bench")
    parser.add_argument("--result_dir", type=str, default="result")

    return parser.parse_args(argv)
| 17 | + |
def scrollable_text(text):
    """Wrap *text* in an HTML ``<div>`` capped at 300px with vertical scroll.

    Intended for use with ``st.markdown(..., unsafe_allow_html=True)`` so
    long questions/answers/predictions don't blow up the row height.
    """
    style = "max-height: 300px; overflow-y: auto; height: auto;"
    return f'<div style="{style}">{text}</div>'
| 23 | + |
if __name__ == "__main__":
    args = parse_args()

    # Instantiate the requested benchmark task from the registry.
    task = eval_mm.tasks.TaskRegistry().get_task_cls(args.task_id)(
        eval_mm.tasks.TaskConfig()
    )

    # Models whose predictions are displayed side by side.
    model_list = [
        "google/gemma-3-12b-it",
        "google/gemma-3-27b-it",
        "microsoft/Phi-4-multimodal-instruct",
    ]
    # Load each model's predictions: one JSON object per line, assumed to be
    # in the same row order as the task dataset -- TODO confirm against the
    # prediction writer.
    predictions_per_model = {}
    for model_id in model_list:
        prediction_path = os.path.join(
            args.result_dir, args.task_id, model_id, "prediction.jsonl"
        )
        with open(prediction_path, "r", encoding="utf-8") as f:
            predictions_per_model[model_id] = [json.loads(line) for line in f]

    # Load the VQA dataset.
    ds = task.dataset
    # Initialize session state.
    st.set_page_config(layout="wide")
    if "page" not in st.session_state:
        st.session_state.page = 0  # current page number

    SAMPLES_PER_PAGE = 30  # number of samples shown per page
    # Fix: last valid page index. The previous `len(ds) // SAMPLES_PER_PAGE`
    # allowed navigating to an empty page whenever len(ds) was an exact
    # multiple of SAMPLES_PER_PAGE.
    max_page = max(0, (len(ds) - 1) // SAMPLES_PER_PAGE)

    # Columns: Question ID, Image, Question, Answer, then one per model.
    column_width_list = [1, 3, 3, 3] + [4] * len(model_list)
    st.write(f"# {args.task_id} dataset")

    def show_sample(idx):
        """Render one dataset row: id, image, question, answer, predictions."""
        sample = ds[idx]
        cols = st.columns(column_width_list)
        cols[0].markdown(task.doc_to_id(sample))
        cols[1].image(task.doc_to_visual(sample)[0], width=300)
        cols[2].markdown(
            scrollable_text(task.doc_to_text(sample)), unsafe_allow_html=True
        )
        cols[3].markdown(
            scrollable_text(task.doc_to_answer(sample)), unsafe_allow_html=True
        )
        # enumerate instead of repeated model_list.index() lookups (O(n^2)).
        for offset, model_id in enumerate(model_list):
            cols[4 + offset].markdown(
                scrollable_text(predictions_per_model[model_id][idx]["text"]),
                unsafe_allow_html=True,
            )

    # Navigation buttons.
    nav_col1, nav_col2, nav_col3 = st.columns(3)
    if nav_col1.button(f"Prev {SAMPLES_PER_PAGE}"):
        st.session_state.page = max(st.session_state.page - 1, 0)
    if nav_col2.button("Random"):
        st.session_state.page = random.randint(0, max_page)
    if nav_col3.button(f"Next {SAMPLES_PER_PAGE}"):
        st.session_state.page = min(st.session_state.page + 1, max_page)

    # Index range of samples on the current page.
    start_idx = st.session_state.page * SAMPLES_PER_PAGE
    end_idx = min(start_idx + SAMPLES_PER_PAGE, len(ds))

    st.write(f"### Showing samples {start_idx + 1} to {end_idx} / {len(ds)}")

    # Header row.
    header_cols = st.columns(column_width_list)
    header_cols[0].markdown("ID")
    header_cols[1].markdown("Image")
    header_cols[2].markdown("Question")
    header_cols[3].markdown("Answer")
    for offset, model_id in enumerate(model_list):
        header_cols[4 + offset].markdown(f"Prediction ({model_id})")

    # Render each sample on the page.
    for idx in range(start_idx, end_idx):
        with st.container():
            show_sample(idx)
            st.markdown("---")
0 commit comments