Skip to content

Commit 17fb7b5

Browse files
authored
Merge pull request #11 from LPC-HH/main
Update
2 parents 60aa367 + e6fe4d5 commit 17fb7b5

File tree

5 files changed

+377
-142
lines changed

5 files changed

+377
-142
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,6 @@ jobs:
5050
--durations=20
5151
5252
- name: Upload coverage report
53-
uses: codecov/codecov-action@v5.3.1
53+
uses: codecov/codecov-action@v5.4.2
5454
with:
5555
token: ${{ secrets.CODECOV_TOKEN }}

src/boostedhh/plotting.py

Lines changed: 158 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -550,7 +550,7 @@ def ratioHistPlot(
550550
if ylim is not None:
551551
ax.set_ylim([y_lowlim, ylim])
552552
else:
553-
ax.set_ylim(y_lowlim)
553+
ax.set_ylim(y_lowlim, ax.get_ylim()[1] * 2)
554554

555555
ax.margins(x=0)
556556

@@ -787,7 +787,7 @@ def _find_nearest(array, value):
787787
return idx
788788

789789

790-
def multiROCCurveGrey(
790+
def multiROCCurveGreyOld(
791791
rocs: dict,
792792
sig_effs: list[float],
793793
plot_dir: Path,
@@ -858,6 +858,162 @@ def multiROCCurveGrey(
858858
]
859859

860860

861+
def _mark_roc_point(ax, x, y, label, colour):
    """Draw a labelled marker at (x, y) plus dashed guide lines to both axes."""
    ax.scatter([x], [y], marker="o", s=40, label=label, zorder=100, color=colour)
    ax.vlines(x=x, ymin=0, ymax=y, color=colour, linestyles="dashed", alpha=0.5)
    ax.hlines(y=y, xmin=0, xmax=x, color=colour, linestyles="dashed", alpha=0.5)


def multiROCCurveGrey(
    rocs: dict,
    sig_effs: list[float] = None,
    bkg_effs: list[float] = None,
    xlim=None,
    ylim=None,
    plot_dir: Path = None,
    name: str = "",
    show: bool = False,
    add_cms_label=False,
    legtitle: str = None,
    title: str = None,
    plot_thresholds: dict = None,  # plot signal and bkg efficiency for a given discriminator threshold
    find_from_sigeff: dict = None,  # find discriminator threshold that matches signal efficiency
):
    """Plot multiple ROC curves (e.g. train and test) + multiple signals.

    Signal efficiency (``tpr``) is drawn on the x axis and background
    efficiency (``fpr``) on a logarithmic y axis.

    Args:
        rocs: Two-level dict. Each value of the outer dict is itself a dict
            mapping a key to one ROC record with entries ``"tpr"``, ``"fpr"``,
            ``"label"`` and ``"color"``; ``"thresholds"`` is additionally
            read when ``plot_thresholds`` or ``find_from_sigeff`` names that key.
        sig_effs: Signal efficiencies at which grey dashed guide lines are
            drawn on every curve.
        bkg_effs: Background efficiencies at which grey dashed guide lines are
            drawn on every curve.
        xlim: x-axis limits; defaults to ``[0, 1]``.
        ylim: y-axis limits; defaults to ``[1e-06, 1]``.
        plot_dir: Directory the figure is saved into; required when ``name``
            is non-empty.
        name: Output file stem. When non-empty, ``<name>.png`` and
            ``<name>.pdf`` are written to ``plot_dir``.
        show: Show the figure if True, otherwise close it.
        add_cms_label: Add the CMS label through ``hep.cms.label``.
        legtitle: Optional legend title.
        title: Optional plot title.
        plot_thresholds: Maps an inner ``rocs`` key to discriminator
            thresholds; the curve point nearest each threshold is marked.
        find_from_sigeff: Maps an inner ``rocs`` key to signal efficiencies;
            the curve point nearest each efficiency is marked and labelled
            with the threshold found there.

    Raises:
        ValueError: If ``name`` is non-empty but ``plot_dir`` is None.
    """
    # Fail before any figure is created instead of with a TypeError at save time.
    if name and plot_dir is None:
        raise ValueError("plot_dir must be provided when name is set")

    if ylim is None:
        ylim = [1e-06, 1]
    if xlim is None:
        xlim = [0, 1]
    # Treating a missing mapping as empty lets either marker option be used
    # on its own (previously find_from_sigeff needed plot_thresholds as well).
    if plot_thresholds is None:
        plot_thresholds = {}
    if find_from_sigeff is None:
        find_from_sigeff = {}

    line_style = {"colors": "lightgrey", "linestyles": "dashed"}
    th_colours = ["cornflowerblue", "deepskyblue", "mediumblue", "cyan", "cadetblue"]
    eff_colours = ["lime", "aquamarine", "greenyellow"]

    fig = plt.figure(figsize=(12, 12))
    ax = fig.gca()

    # ROC curves for each type of signal, with optional grey guide lines
    for roc_sigs in rocs.values():
        for roc in roc_sigs.values():
            tpr, fpr = roc["tpr"], roc["fpr"]
            ax.plot(tpr, fpr, label=roc["label"], color=roc["color"], linewidth=2)

            # NOTE(review): searchsorted assumes tpr / fpr are sorted ascending - confirm.
            # The index is clipped so that a requested efficiency above the
            # curve's maximum uses the last point instead of raising IndexError.
            if sig_effs is not None:
                for sig_eff in sig_effs:
                    idx = min(np.searchsorted(tpr, sig_eff), len(fpr) - 1)
                    y = fpr[idx]
                    ax.hlines(y=y, xmin=0, xmax=sig_eff, **line_style)
                    ax.vlines(x=sig_eff, ymin=0, ymax=y, **line_style)

            if bkg_effs is not None:
                for bkg_eff in bkg_effs:
                    idx = min(np.searchsorted(fpr, bkg_eff), len(tpr) - 1)
                    x = tpr[idx]
                    ax.vlines(x=x, ymin=0, ymax=bkg_eff, **line_style)
                    ax.hlines(y=bkg_eff, xmin=0, xmax=x, **line_style)

    # Markers for requested classifier thresholds / signal efficiencies
    for roc_sigs in rocs.values():
        i_th = 0
        i_sigeff = 0
        for rockey, roc in roc_sigs.items():
            for th in plot_thresholds.get(rockey, ()):
                idx = _find_nearest(roc["thresholds"], th)
                _mark_roc_point(
                    ax,
                    roc["tpr"][idx],
                    roc["fpr"][idx],
                    rf"{rockey} > {th:.2f}",
                    # wrap around so more markers than colours cannot overflow
                    th_colours[i_th % len(th_colours)],
                )
                i_th += 1

            for sig_eff in find_from_sigeff.get(rockey, ()):
                idx = _find_nearest(roc["tpr"], sig_eff)
                threshold = roc["thresholds"][idx]
                _mark_roc_point(
                    ax,
                    roc["tpr"][idx],
                    roc["fpr"][idx],
                    rf"{rockey} > {threshold:.2f}",
                    eff_colours[i_sigeff % len(eff_colours)],
                )
                i_sigeff += 1

    if add_cms_label:
        hep.cms.label(data=False, rlabel="")
    if title:
        ax.set_title(title)
    ax.set_yscale("log")
    ax.set_xlabel("Signal efficiency")
    ax.set_ylabel("Background efficiency")
    ax.set_xlim(*xlim)
    ax.set_ylim(*ylim)
    ax.xaxis.grid(True, which="major")
    ax.yaxis.grid(True, which="major")
    # An empty or missing legtitle means no legend title.
    ax.legend(title=legtitle or None, loc="center left", bbox_to_anchor=(1, 0.5))

    if name:
        out_dir = Path(plot_dir)
        for ext in ("png", "pdf"):
            fig.savefig(out_dir / f"{name}.{ext}", bbox_inches="tight")

    if show:
        plt.show()
    else:
        plt.close(fig)
1015+
1016+
8611017
def multiROCCurve(
8621018
rocs: dict,
8631019
thresholds=None,

src/boostedhh/run_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,8 @@ def parse_common_hh_args(parser):
109109
"--year",
110110
help="year",
111111
type=str,
112-
default="2022",
112+
nargs="+",
113+
required=True,
113114
choices=["2018", "2022", "2022EE", "2023", "2023BPix"],
114115
)
115116

src/boostedhh/submit_utils.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,13 @@ def check_branch(
133133
def init_args(args):
134134
# check that branch exists
135135
check_branch(args.analysis, args.git_branch, args.git_user, args.allow_diff_local_repo)
136+
137+
if isinstance(args.year, list):
138+
if len(args.year) == 1:
139+
args.year = args.year[0]
140+
else:
141+
raise ValueError("Submitting multiple years without --yaml option is not supported yet")
142+
136143
username = os.environ["USER"]
137144

138145
if args.site == "lpc":
@@ -189,8 +196,8 @@ def submit(
189196

190197
# submit jobs
191198
nsubmit = 0
192-
for sample in fileset:
193-
for subsample, tot_files in fileset[sample].items():
199+
for sample, sfiles in fileset.items():
200+
for subsample, tot_files in sfiles.items():
194201
if args.submit:
195202
print("Submitting " + subsample)
196203

@@ -239,7 +246,7 @@ def submit(
239246
Path(f"{localcondor}.log").unlink()
240247

241248
if args.submit:
242-
os.system("condor_submit %s" % localcondor)
249+
os.system(f"condor_submit {localcondor}")
243250
else:
244251
print("To submit ", localcondor)
245252
nsubmit = nsubmit + 1
@@ -250,14 +257,14 @@ def submit(
250257
def replace_batch_size(file_path: Path, new_batch_size: int):
    """Rewrite every ``--batch-size <digits>`` occurrence in a file.

    The file at ``file_path`` is read in full, each match of
    ``--batch-size \\d+`` is replaced with ``--batch-size {new_batch_size}``,
    and the result is written back to the same path.

    Args:
        file_path: Path of the file to modify in place.
        new_batch_size: Value substituted after ``--batch-size``.
    """
    import re

    original_text = file_path.read_text()
    replacement = f"--batch-size {new_batch_size}"
    file_path.write_text(re.sub(r"--batch-size \d+", replacement, original_text))

0 commit comments

Comments
 (0)