|
| 1 | +import os |
| 2 | +import json |
| 3 | +import pandas as pd |
| 4 | +import matplotlib.pyplot as plt |
| 5 | +import japanize_matplotlib |
| 6 | +import numpy as np |
| 7 | + |
| 8 | +# ====================================== |
| 9 | +# モデルごとのjsonlファイルパスを設定 |
| 10 | +# ====================================== |
| 11 | +vis_dir = "logs/mecha-ja/visualize/" |
| 12 | +root_dir = "logs/mecha-ja/prediction/" |
| 13 | +files = os.listdir(root_dir) |
| 14 | +model_names = [ |
| 15 | + os.path.basename(f).replace(".jsonl", "") for f in files if f.endswith(".jsonl") |
| 16 | +] |
| 17 | + |
| 18 | +model_files = { |
| 19 | + model_name: os.path.join(root_dir, f"{model_name}.jsonl") |
| 20 | + for model_name in model_names |
| 21 | +} |
| 22 | + |
| 23 | +model_dfs = {} |
| 24 | + |
| 25 | + |
| 26 | +# ====================================== |
| 27 | +# JSONLを読み込み、DataFrame化する関数 |
| 28 | +# ====================================== |
| 29 | +def load_jsonl_to_df(file_path): |
| 30 | + data_list = [] |
| 31 | + with open(file_path, "r", encoding="utf-8") as f: |
| 32 | + for line in f: |
| 33 | + data_list.append(json.loads(line.strip())) |
| 34 | + return pd.DataFrame(data_list) |
| 35 | + |
| 36 | + |
| 37 | +# ====================================== |
| 38 | +# データ読み込み & 予測列の追加 |
| 39 | +# ====================================== |
| 40 | +def check_abcd(text): |
| 41 | + letters = ["A", "B", "C", "D"] |
| 42 | + found = [ |
| 43 | + ch for ch in letters if ch in text |
| 44 | + ] # テキスト中に含まれる A/B/C/D をリスト化 |
| 45 | + # 含まれている文字がちょうど1つならその文字、そうでなければ F を返す |
| 46 | + return found[0] if len(found) == 1 else "F" |
| 47 | + |
| 48 | + |
| 49 | +for model_name, file_path in model_files.items(): |
| 50 | + df = load_jsonl_to_df(file_path) |
| 51 | + df["pred"] = df["text"].apply(check_abcd) |
| 52 | + model_dfs[model_name] = df |
| 53 | + |
| 54 | +rotate_map = { |
| 55 | + "A": ["A", "D", "C", "B"], |
| 56 | + "B": ["B", "A", "D", "C"], |
| 57 | + "C": ["C", "B", "A", "D"], |
| 58 | + "D": ["D", "C", "B", "A"], |
| 59 | +} |
| 60 | + |
| 61 | +# ====================================== |
| 62 | +# (1) モデルごとの回答選択肢の分布を可視化(相対頻度) |
| 63 | +# ====================================== |
| 64 | +# 3 x 3 のグリッドで可視化 |
| 65 | +fig, axes = plt.subplots(nrows=3, ncols=3, figsize=(10, 10), sharex=True, sharey=True) |
| 66 | + |
| 67 | +# axesを1次元にする |
| 68 | +axes = axes.flatten() |
| 69 | + |
| 70 | +for ax, (model_name, df) in zip(axes, model_dfs.items()): |
| 71 | + # 予測回答の分布(相対頻度)をカウント |
| 72 | + pred_counts = ( |
| 73 | + df["pred"].value_counts().reindex(["A", "B", "C", "D", "F"], fill_value=0) |
| 74 | + ) |
| 75 | + pred_counts = pred_counts / pred_counts.sum() # 相対頻度に変換 |
| 76 | + |
| 77 | + ax.bar( |
| 78 | + pred_counts.index, |
| 79 | + pred_counts.values, |
| 80 | + color=["#FF9999", "#FFE888", "#99FF99", "#99CCFF", "#CCCCCC"], |
| 81 | + ) |
| 82 | + |
| 83 | + # 0.25 に赤い線を引く |
| 84 | + ax.axhline(y=0.25, color="r", linestyle="--", linewidth=1) |
| 85 | + |
| 86 | + ax.set_title(f"{model_name}") |
| 87 | + ax.set_xlabel("選択肢") |
| 88 | + ax.set_ylabel("選択頻度") |
| 89 | + |
| 90 | +plt.tight_layout() |
| 91 | +plt.savefig(os.path.join(vis_dir, "prediction_distribution.png")) |
| 92 | +plt.close() |
| 93 | + |
| 94 | +# ====================================== |
| 95 | +# (2) soft accuracy, strict accuracy, consistency の計算 |
| 96 | +# -------------------------------------- |
| 97 | +# ・soft accuracy: 単純に df["mecha-ja"] の平均値 |
| 98 | +# ・strict accuracy: 同じ問題 (q_base) で全て mecha-ja==1 なら正解とカウント |
| 99 | +# ・consistency: rot0 の予測に応じて rotate_map 通りになっている割合 |
| 100 | +# ====================================== |
| 101 | +results = [] |
| 102 | + |
| 103 | +for model_name, df in model_dfs.items(): |
| 104 | + # soft accuracy |
| 105 | + soft_accuracy = df["mecha-ja"].mean() # 1と0の平均 = 正解率 |
| 106 | + |
| 107 | + # "X_rotY" の "X" 部分を q_base として抽出 |
| 108 | + df["q_base"] = df["question_id"].apply(lambda x: x.split("_rot")[0]) |
| 109 | + # 回転番号を取得 |
| 110 | + df["rot"] = df["question_id"].apply(lambda x: int(x.split("_rot")[1])) |
| 111 | + |
| 112 | + grouped = df.groupby("q_base") |
| 113 | + unique_questions = df["q_base"].unique() |
| 114 | + |
| 115 | + correct_count = 0 # strict用 |
| 116 | + consistent_count = 0 |
| 117 | + |
| 118 | + for q_id, group in grouped: |
| 119 | + # strict正答率 (全rotで mecha-ja == 1) |
| 120 | + if all(group["mecha-ja"] == 1): |
| 121 | + correct_count += 1 |
| 122 | + |
| 123 | + # 一貫性 (rotate_map に従っているか) |
| 124 | + group_sorted = group.sort_values("rot") |
| 125 | + preds = group_sorted["pred"].tolist() # rot0→rot1→rot2→rot3 の順 |
| 126 | + |
| 127 | + pred_rot0 = preds[0] # 最初が rot0 |
| 128 | + if pred_rot0 in rotate_map: |
| 129 | + expected_sequence = rotate_map[pred_rot0] |
| 130 | + # 一貫しているかどうか |
| 131 | + if len(preds) == 4 and preds == expected_sequence: |
| 132 | + consistent_count += 1 |
| 133 | + |
| 134 | + strict_accuracy = ( |
| 135 | + correct_count / len(unique_questions) if len(unique_questions) else 0 |
| 136 | + ) |
| 137 | + consistency = ( |
| 138 | + consistent_count / len(unique_questions) if len(unique_questions) else 0 |
| 139 | + ) |
| 140 | + |
| 141 | + results.append( |
| 142 | + { |
| 143 | + "model": model_name, |
| 144 | + "soft_accuracy": soft_accuracy, |
| 145 | + "strict_accuracy": strict_accuracy, |
| 146 | + "consistency": consistency, |
| 147 | + } |
| 148 | + ) |
| 149 | + |
| 150 | +results_df = pd.DataFrame(results) |
| 151 | +print(results_df) |
| 152 | + |
| 153 | +# ====================================== |
| 154 | +# (3) 3種の指標 (soft, strict, consistency) を棒グラフで比較 |
| 155 | +# ====================================== |
| 156 | +metrics = ["soft_accuracy", "strict_accuracy", "consistency"] |
| 157 | +x = np.arange(len(results_df)) # モデル数 |
| 158 | +width = 0.25 # 棒の幅 |
| 159 | + |
| 160 | +fig, ax = plt.subplots(figsize=(8, 8)) |
| 161 | + |
| 162 | +ax.bar(x, results_df["soft_accuracy"], width=width, label="Soft Accuracy", alpha=0.7) |
| 163 | +ax.bar( |
| 164 | + x - width, |
| 165 | + results_df["strict_accuracy"], |
| 166 | + width=width, |
| 167 | + label="Strict Accuracy", |
| 168 | + alpha=0.7, |
| 169 | +) |
| 170 | +ax.bar( |
| 171 | + x + width + 0.05, |
| 172 | + results_df["consistency"], |
| 173 | + width=width, |
| 174 | + label="Consistency", |
| 175 | + alpha=0.7, |
| 176 | +) |
| 177 | + |
| 178 | +ax.set_xticks(x) |
| 179 | +ax.set_xticklabels(results_df["model"], rotation=90) |
| 180 | +ax.set_ylim(0, 1) |
| 181 | +ax.set_ylabel("Rate") |
| 182 | +ax.set_title("Soft/Strict Accuracy & Consistency by Model") |
| 183 | +ax.legend() |
| 184 | + |
| 185 | +plt.tight_layout() |
| 186 | +plt.savefig(os.path.join(vis_dir, "accuracy_consistency_comparison.png")) |
| 187 | +plt.close() |
0 commit comments