-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathplot_academic_curves.py
More file actions
92 lines (75 loc) · 4.09 KB
/
plot_academic_curves.py
File metadata and controls
92 lines (75 loc) · 4.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import make_interp_spline
# ==========================================
# 1. Ablation results measured at each prediction threshold
# ==========================================
# One row per threshold: (threshold, Dice, recall a.k.a. sensitivity, specificity).
_ablation_rows = [
    (0.20, 0.6681, 0.8444, 0.9331),
    (0.25, 0.6905, 0.8217, 0.9453),
    (0.30, 0.7253, 0.6749, 0.9818),
    (0.35, 0.7153, 0.7870, 0.9595),
    (0.40, 0.7244, 0.7679, 0.9655),
    (0.45, 0.7302, 0.7419, 0.9716),
    (0.50, 0.7292, 0.7162, 0.9757),
]
# Transpose the table into one 1-D float array per column.
thresholds, dice, recall, specificity = (
    np.array(column) for column in zip(*_ablation_rows)
)

# Recover precision from Dice and recall. Treating Dice as the F1 score,
#   F1 = 2PR / (P + R)   =>   P = F1 * R / (2R - F1).
# NOTE(review): the identity is exact only if Dice and recall were computed
# from the same confusion counts -- confirm against the evaluation code.
_dice_times_recall = dice * recall
precision = _dice_times_recall / (2 * recall - dice)
# ==========================================
# 2. Spline interpolation (smooth curves for the figures)
# ==========================================
# Dense, evenly spaced threshold grid spanning the measured range.
x_smooth = np.linspace(thresholds.min(), thresholds.max(), 300)

# Fit one interpolating cubic (k=3) B-spline per metric ...
spline_dice, spline_recall, spline_spec = (
    make_interp_spline(thresholds, measured, k=3)
    for measured in (dice, recall, specificity)
)
# ... and evaluate each of them on the dense grid.
dice_smooth, recall_smooth, spec_smooth = (
    fitted(x_smooth) for fitted in (spline_dice, spline_recall, spline_spec)
)

# Locate the highest Dice value along the smoothed curve.
# NOTE(review): this is the peak of the interpolated curve, so neither the
# threshold nor the Dice value has to coincide with a measured data point --
# confirm that reporting the interpolated optimum is intended.
optimal_idx = np.argmax(dice_smooth)
optimal_thresh, optimal_dice = x_smooth[optimal_idx], dice_smooth[optimal_idx]
# ==========================================
# 3. Figure 1: metrics vs. threshold
# ==========================================
plt.style.use('seaborn-v0_8-whitegrid')
fig1, ax1 = plt.subplots(figsize=(8, 6), dpi=300)

# Smoothed curves, drawn in a fixed order so the legend lists
# Dice, sensitivity, specificity.
_curve_specs = (
    (dice_smooth, dict(label='Dice (F1-Score)', color='#d62728', linewidth=2.5)),
    (recall_smooth, dict(label='Sensitivity (Recall)', color='#1f77b4', linewidth=2, linestyle='--')),
    (spec_smooth, dict(label='Specificity', color='#2ca02c', linewidth=2, linestyle='-.')),
)
for _curve, _line_kwargs in _curve_specs:
    ax1.plot(x_smooth, _curve, **_line_kwargs)

# Overlay the measured data points (unlabelled, so they stay out of the
# legend) in the colour of their curve; zorder=5 keeps them above the lines.
_measured_specs = ((dice, '#d62728'), (recall, '#1f77b4'), (specificity, '#2ca02c'))
for _measured, _colour in _measured_specs:
    ax1.scatter(thresholds, _measured, color=_colour, s=50, zorder=5)

# Mark the threshold at which the smoothed Dice curve peaks and annotate it
# slightly up and to the left of the peak.
ax1.axvline(x=optimal_thresh, color='gray', linestyle=':', linewidth=1.5)
_peak_label = f'Optimal Thresh: {optimal_thresh:.2f}\nMax Dice: {optimal_dice:.4f}'
ax1.text(
    optimal_thresh - 0.01,
    optimal_dice + 0.015,
    _peak_label,
    color='#d62728',
    fontweight='bold',
    ha='right',
)

# Axis labels, title, legend and limits.
_axis_label_font = dict(fontsize=12, fontweight='bold')
ax1.set_xlabel('Prediction Threshold', **_axis_label_font)
ax1.set_ylabel('Metric Score', **_axis_label_font)
ax1.set_title(
    'Impact of Threshold on Segmentation Performance',
    fontsize=14,
    fontweight='bold',
    pad=15,
)
ax1.legend(loc='lower right', frameon=True, shadow=True, fontsize=11)
ax1.set_ylim(0.6, 1.0)

# Write the figure to disk and report it.
fig1.tight_layout()
fig1.savefig('Metrics_vs_Threshold.png')
print("✅ 图表 1 (Metrics_vs_Threshold.png) 已生成!")
# ==========================================
# 4. Figure 2: precision-recall (PR) curve
# ==========================================
fig2, ax2 = plt.subplots(figsize=(7, 7), dpi=300)

# Smooth the PR curve as well. The spline is fitted with recall as the
# abscissa, so the points are first reordered by increasing recall.
recall_sort_idx = np.argsort(recall)
recall_sorted = recall[recall_sort_idx]
precision_sorted = precision[recall_sort_idx]
spline_pr = make_interp_spline(recall_sorted, precision_sorted, k=2)
recall_pr_smooth = np.linspace(recall_sorted.min(), recall_sorted.max(), 300)
precision_pr_smooth = spline_pr(recall_pr_smooth)

ax2.plot(recall_pr_smooth, precision_pr_smooth, color='#9467bd', linewidth=3, label='Swin-UNet (V3)')
ax2.scatter(recall, precision, color='#8c564b', s=60, zorder=5, label='Ablation Points')

# Highlight the measured operating point with the highest Dice (F1) score.
# The index is derived from the data instead of being hard-coded: the
# previous literal index 4 selected threshold 0.40, whereas the maximum
# measured Dice (0.7302) belongs to threshold 0.45 at index 5.
optimal_point_idx = int(np.argmax(dice))
optimal_recall = recall[optimal_point_idx]
optimal_precision = precision[optimal_point_idx]
ax2.plot(optimal_recall, optimal_precision, 'r*', markersize=15, label='Optimal F1-Score Point')

# Axis labels, title, legend and grid.
ax2.set_xlabel('Recall (Sensitivity)', fontsize=12, fontweight='bold')
ax2.set_ylabel('Precision', fontsize=12, fontweight='bold')
ax2.set_title('Precision-Recall Curve on DRIVE Dataset', fontsize=14, fontweight='bold', pad=15)
ax2.legend(loc='lower left', frameon=True, shadow=True, fontsize=11)
ax2.grid(True, linestyle='--', alpha=0.7)

# Write the figure to disk and report it.
fig2.tight_layout()
fig2.savefig('PR_Curve.png')
print("✅ 图表 2 (PR_Curve.png) 已生成!")