Skip to content

Commit 0b73db5

Browse files
committed
revert multiroccurvegrey
1 parent 2869780 commit 0b73db5

File tree

1 file changed

+157
-1
lines changed

1 file changed

+157
-1
lines changed

src/boostedhh/plotting.py

Lines changed: 157 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -787,7 +787,7 @@ def _find_nearest(array, value):
787787
return idx
788788

789789

790-
def multiROCCurveGrey(
790+
def multiROCCurveGreyOld(
791791
rocs: dict,
792792
sig_effs: list[float],
793793
plot_dir: Path,
@@ -858,6 +858,162 @@ def multiROCCurveGrey(
858858
]
859859

860860

861+
def multiROCCurveGrey(
862+
rocs: dict,
863+
sig_effs: list[float] = None,
864+
bkg_effs: list[float] = None,
865+
xlim=None,
866+
ylim=None,
867+
plot_dir: Path = None,
868+
name: str = "",
869+
show: bool = False,
870+
add_cms_label=False,
871+
legtitle: str = None,
872+
title: str = None,
873+
plot_thresholds: dict = None, # plot signal and bkg efficiency for a given discriminator threshold
874+
find_from_sigeff: dict = None, # find discriminator threshold that matches signal efficiency
875+
):
876+
"""Plot multiple ROC curves (e.g. train and test) + multiple signals"""
877+
if ylim is None:
878+
ylim = [1e-06, 1]
879+
if xlim is None:
880+
xlim = [0, 1]
881+
line_style = {"colors": "lightgrey", "linestyles": "dashed"}
882+
th_colours = ["cornflowerblue", "deepskyblue", "mediumblue", "cyan", "cadetblue"]
883+
eff_colours = ["lime", "aquamarine", "greenyellow"]
884+
885+
fig = plt.figure(figsize=(12, 12))
886+
ax = fig.gca()
887+
for roc_sigs in rocs.values():
888+
889+
# plots roc curves for each type of signal
890+
for roc in roc_sigs.values():
891+
892+
plt.plot(
893+
roc["tpr"],
894+
roc["fpr"],
895+
label=roc["label"],
896+
color=roc["color"],
897+
linewidth=2,
898+
)
899+
900+
# determines the point on the ROC curve that corresponds to the signal efficiency
901+
# plots a vertical and horizontal line to the point
902+
if sig_effs is not None:
903+
for sig_eff in sig_effs:
904+
y = roc["fpr"][np.searchsorted(roc["tpr"], sig_eff)]
905+
plt.hlines(y=y, xmin=0, xmax=sig_eff, **line_style)
906+
plt.vlines(x=sig_eff, ymin=0, ymax=y, **line_style)
907+
908+
# determines the point on the ROC curve that corresponds to the background efficiency
909+
# plots a vertical and horizontal line to the point
910+
if bkg_effs is not None:
911+
for bkg_eff in bkg_effs:
912+
x = roc["tpr"][np.searchsorted(roc["fpr"], bkg_eff)]
913+
plt.vlines(x=x, ymin=0, ymax=bkg_eff, **line_style)
914+
plt.hlines(y=bkg_eff, xmin=0, xmax=x, **line_style)
915+
916+
# plots points and lines on plot corresponding to classifier thresholds
917+
for roc_sigs in rocs.values():
918+
if plot_thresholds is None:
919+
break
920+
i_sigeff = 0
921+
i_th = 0
922+
for rockey, roc in roc_sigs.items():
923+
if rockey in plot_thresholds:
924+
pths = {th: [[], []] for th in plot_thresholds[rockey]}
925+
for th in plot_thresholds[rockey]:
926+
idx = _find_nearest(roc["thresholds"], th)
927+
pths[th][0].append(roc["tpr"][idx])
928+
pths[th][1].append(roc["fpr"][idx])
929+
for th in plot_thresholds[rockey]:
930+
plt.scatter(
931+
*pths[th],
932+
marker="o",
933+
s=40,
934+
label=rf"{rockey} > {th:.2f}",
935+
zorder=100,
936+
color=th_colours[i_th],
937+
)
938+
plt.vlines(
939+
x=pths[th][0],
940+
ymin=0,
941+
ymax=pths[th][1],
942+
color=th_colours[i_th],
943+
linestyles="dashed",
944+
alpha=0.5,
945+
)
946+
plt.hlines(
947+
y=pths[th][1],
948+
xmin=0,
949+
xmax=pths[th][0],
950+
color=th_colours[i_th],
951+
linestyles="dashed",
952+
alpha=0.5,
953+
)
954+
i_th += 1
955+
956+
if find_from_sigeff is not None and rockey in find_from_sigeff:
957+
pths = {sig_eff: [[], []] for sig_eff in find_from_sigeff[rockey]}
958+
thrs = {}
959+
for sig_eff in find_from_sigeff[rockey]:
960+
idx = _find_nearest(roc["tpr"], sig_eff)
961+
thrs[sig_eff] = roc["thresholds"][idx]
962+
pths[sig_eff][0].append(roc["tpr"][idx])
963+
pths[sig_eff][1].append(roc["fpr"][idx])
964+
for sig_eff in find_from_sigeff[rockey]:
965+
plt.scatter(
966+
*pths[sig_eff],
967+
marker="o",
968+
s=40,
969+
label=rf"{rockey} > {thrs[sig_eff]:.2f}",
970+
zorder=100,
971+
color=eff_colours[i_sigeff],
972+
)
973+
plt.vlines(
974+
x=pths[sig_eff][0],
975+
ymin=0,
976+
ymax=pths[sig_eff][1],
977+
color=eff_colours[i_sigeff],
978+
linestyles="dashed",
979+
alpha=0.5,
980+
)
981+
plt.hlines(
982+
y=pths[sig_eff][1],
983+
xmin=0,
984+
xmax=pths[sig_eff][0],
985+
color=eff_colours[i_sigeff],
986+
linestyles="dashed",
987+
alpha=0.5,
988+
)
989+
i_sigeff += 1
990+
991+
if add_cms_label:
992+
hep.cms.label(data=False, rlabel="")
993+
if title:
994+
plt.title(title)
995+
plt.yscale("log")
996+
plt.xlabel("Signal efficiency")
997+
plt.ylabel("Background efficiency")
998+
plt.xlim(*xlim)
999+
plt.ylim(*ylim)
1000+
ax.xaxis.grid(True, which="major")
1001+
ax.yaxis.grid(True, which="major")
1002+
if legtitle:
1003+
plt.legend(title=legtitle, loc="center left", bbox_to_anchor=(1, 0.5))
1004+
else:
1005+
plt.legend(loc="center left", bbox_to_anchor=(1, 0.5))
1006+
1007+
if len(name):
1008+
plt.savefig(plot_dir / f"{name}.png", bbox_inches="tight")
1009+
plt.savefig(plot_dir / f"{name}.pdf", bbox_inches="tight")
1010+
1011+
if show:
1012+
plt.show()
1013+
else:
1014+
plt.close()
1015+
1016+
8611017
def multiROCCurve(
8621018
rocs: dict,
8631019
thresholds=None,

0 commit comments

Comments
 (0)