Skip to content

Commit 6166192

Browse files
committed
Lint
1 parent 9319df2 commit 6166192

File tree

3 files changed

+37
-57
lines changed

3 files changed

+37
-57
lines changed

machine_learning_hep/analysis/do_systematics.py

Lines changed: 17 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,6 @@
1515
Author: Vit Kucera <[email protected]>
1616
"""
1717

18-
# pylint: disable=too-many-lines, too-many-instance-attributes, too-many-statements, too-many-locals
19-
# pylint: disable=too-many-nested-blocks, too-many-branches, consider-using-f-string
20-
2118
import argparse
2219
import logging
2320
import os
@@ -94,7 +91,7 @@ def __init__(self, path_database_analysis: str, typean: str):
9491

9592
with open(path_database_analysis, encoding="utf-8") as file_in:
9693
db_analysis = yaml.safe_load(file_in)
97-
case = list(db_analysis.keys())[0]
94+
case = next(iter(db_analysis.keys()))
9895
self.datap = db_analysis[case]
9996
self.db_typean = self.datap["analysis"][self.typean]
10097

@@ -790,28 +787,21 @@ def do_jet_systematics(self, var: str):
790787
count_sys_up = count_sys_up + 1
791788
else:
792789
error_var_up = max(error_var_up, error)
793-
else:
794-
if self.systematic_rms[sys_cat] is True:
795-
if self.systematic_rms_both_sides[sys_cat] is True:
796-
error_var_up += error * error
797-
if not out_sys:
798-
count_sys_up = count_sys_up + 1
799-
else:
800-
error_var_down += error * error
801-
if not out_sys:
802-
count_sys_down = count_sys_down + 1
790+
elif self.systematic_rms[sys_cat] is True:
791+
if self.systematic_rms_both_sides[sys_cat] is True:
792+
error_var_up += error * error
793+
if not out_sys:
794+
count_sys_up = count_sys_up + 1
803795
else:
804-
error_var_down = max(error_var_down, abs(error))
805-
if self.systematic_rms[sys_cat] is True:
806-
if count_sys_up != 0:
807-
error_var_up = error_var_up / count_sys_up
796+
error_var_down += error * error
797+
if not out_sys:
798+
count_sys_down = count_sys_down + 1
808799
else:
809-
error_var_up = 0.0
800+
error_var_down = max(error_var_down, abs(error))
801+
if self.systematic_rms[sys_cat] is True:
802+
error_var_up = error_var_up / count_sys_up if count_sys_up != 0 else 0.0
810803
error_var_up = sqrt(error_var_up)
811-
if count_sys_down != 0:
812-
error_var_down = error_var_down / count_sys_down
813-
else:
814-
error_var_down = 0.0
804+
error_var_down = error_var_down / count_sys_down if count_sys_down != 0 else 0.0
815805
if self.systematic_rms_both_sides[sys_cat] is True:
816806
error_var_down = error_var_up
817807
else:
@@ -918,7 +908,8 @@ def do_jet_systematics(self, var: str):
918908
shapebins_error_down_cat = []
919909
for ibinshape in range(n_bins_obs_gen):
920910
shapebins_contents_cat.append(0)
921-
if abs(input_histograms_default[iptjet].GetBinContent(ibinshape + 1)) < 1.0e-7:
911+
epsilon_float = 1.0e-7
912+
if abs(input_histograms_default[iptjet].GetBinContent(ibinshape + 1)) < epsilon_float:
922913
print("WARNING!!! Input histogram at bin", iptjet, " equal 0", suffix)
923914
e_up = 0
924915
e_down = 0
@@ -992,7 +983,8 @@ def do_jet_systematics(self, var: str):
992983
suffix = self.get_suffix_ptjet(iptjet)
993984
h_default_stat_err.append(input_histograms_default[iptjet].Clone("h_default_stat_err" + suffix))
994985
for i in range(h_default_stat_err[iptjet].GetNbinsX()):
995-
if abs(input_histograms_default[iptjet].GetBinContent(i + 1)) < 1.0e-7:
986+
epsilon_float = 1.0e-7
987+
if abs(input_histograms_default[iptjet].GetBinContent(i + 1)) < epsilon_float:
996988
print("WARNING!!! Input histogram at bin", iptjet, " equal 0", suffix)
997989
h_default_stat_err[iptjet].SetBinContent(i + 1, 0)
998990
h_default_stat_err[iptjet].SetBinError(i + 1, 0)

machine_learning_hep/plotting/plot_jetsubstructure_run3.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,6 @@
1515
Author: Vit Kucera <[email protected]>
1616
"""
1717

18-
# pylint: disable=too-many-lines, too-many-instance-attributes, too-many-statements, too-many-locals
19-
# pylint: disable=too-many-nested-blocks, too-many-branches, consider-using-f-string
20-
# pylint: disable=unused-variable
21-
2218
import argparse
2319
import logging
2420
import os
@@ -106,7 +102,7 @@ def __init__(self, path_input_file: str, path_database_analysis: str, typean: st
106102

107103
with open(path_database_analysis, encoding="utf-8") as file_db:
108104
db_analysis = yaml.safe_load(file_db)
109-
case = list(db_analysis.keys())[0]
105+
case = next(iter(db_analysis.keys()))
110106
self.datap = db_analysis[case]
111107
self.db_typean = self.datap["analysis"][self.typean]
112108

@@ -406,7 +402,7 @@ def get_run3_sim(self) -> dict:
406402
source = {"monash": "M", "mode2": "SM2"}
407403
dict_obj = {}
408404
with TFile.Open(path_file) as file:
409-
for s_obs, obs in obs.items():
405+
for s_obs, n_obs in obs.items():
410406
dict_obj[s_obs] = {}
411407
for s_spec, spec in species.items():
412408
dict_obj[s_obs][s_spec] = {}
@@ -418,7 +414,7 @@ def get_run3_sim(self) -> dict:
418414
name = pattern % (
419415
spec,
420416
src,
421-
obs,
417+
n_obs,
422418
self.edges_ptjet_gen[iptjet],
423419
self.edges_ptjet_gen[iptjet + 1],
424420
)
@@ -681,7 +677,7 @@ def plot(self):
681677
self.make_plot(f"{self.species}_efficiency_{self.var}")
682678

683679
bins_ptjet = (0, 1, 2, 3)
684-
for cat, label in zip(("pr", "np"), ("prompt", "non-prompt")):
680+
for cat, label in zip(("pr", "np"), ("prompt", "non-prompt"), strict=False):
685681
self.list_obj = self.get_objects(
686682
*(
687683
f"h_ptjet-pthf_effnew_{cat}_{string_range_ptjet(get_bin_limits(axis_ptjet, iptjet + 1))}"
@@ -1206,7 +1202,7 @@ def plot(self):
12061202
# TODO: if plot_run2_lc_ff_data
12071203
if h_run2 is not None:
12081204
n_obj = len(self.list_obj)
1209-
self.plot_order = list(range(n_obj)) + [-1, -0.5]
1205+
self.plot_order = [*list(range(n_obj)), -1, -0.5]
12101206
self.list_obj += [g_run2, h_run2]
12111207
self.labels_obj += [f"{self.text_run2}, {self.get_text_range_ptjet(2)}", ""]
12121208
self.list_colours += [get_colour(-1)] * 2

machine_learning_hep/utilities.py

Lines changed: 15 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -128,10 +128,7 @@ def write_df(dfo, path):
128128

129129
def read_df(path, **kwargs):
130130
try:
131-
if path.endswith(".parquet"):
132-
df = pd.read_parquet(path, **kwargs)
133-
else:
134-
df = pickle.load(openfile(path, "rb"))
131+
df = pd.read_parquet(path, **kwargs) if path.endswith(".parquet") else pickle.load(openfile(path, "rb"))
135132
except Exception as e: # pylint: disable=broad-except
136133
logger.critical("failed to open file <%s>: %s", path, str(e))
137134
sys.exit()
@@ -294,12 +291,12 @@ def make_latex_table(column_names, row_names, rows, caption=None, save_path="./t
294291
columns = "|".join(["c"] * (len(column_names) + 1))
295292
f.write("\\begin{tabular}{" + columns + "}\n")
296293
f.write("\\hline\n")
297-
columns = "&".join([""] + column_names)
294+
columns = "&".join(["", *column_names])
298295
columns = columns.replace("_", "\\_")
299296
f.write(columns + "\\\\\n")
300297
f.write("\\hline\\hline\n")
301-
for rn, row in zip(row_names, rows):
302-
row_string = "&".join([rn] + row)
298+
for rn, row in zip(row_names, rows, strict=False):
299+
row_string = "&".join([rn, *row])
303300
row_string = row_string.replace("_", "\\_")
304301
f.write(row_string + "\\\\\n")
305302
f.write("\\end{tabular}\n")
@@ -349,12 +346,12 @@ def make_message_notfound(name, location=None):
349346

350347

351348
def z_calc(pt_1, phi_1, eta_1, pt_2, phi_2, eta_2):
352-
np_pt_1 = pt_1.values
353-
np_pt_2 = pt_2.values
354-
np_phi_1 = phi_1.values
355-
np_phi_2 = phi_2.values
356-
np_eta_1 = eta_1.values
357-
np_eta_2 = eta_2.values
349+
np_pt_1 = pt_1.to_numpy()
350+
np_pt_2 = pt_2.to_numpy()
351+
np_phi_1 = phi_1.to_numpy()
352+
np_phi_2 = phi_2.to_numpy()
353+
np_eta_1 = eta_1.to_numpy()
354+
np_eta_2 = eta_2.to_numpy()
358355

359356
cos_phi_1 = np.cos(np_phi_1)
360357
cos_phi_2 = np.cos(np_phi_2)
@@ -397,10 +394,7 @@ def equal_axis_list(axis1, list2, precision=10):
397394
bins = get_bins(axis1)
398395
if len(bins) != len(list2):
399396
return False
400-
for i, j in zip(bins, list2):
401-
if round(i, precision) != round(j, precision):
402-
return False
403-
return True
397+
return all(round(i, precision) == round(j, precision) for i, j in zip(bins, list2, strict=False))
404398

405399

406400
def equal_binning(his1, his2):
@@ -418,9 +412,7 @@ def equal_binning_lists(his, list_x=None, list_y=None, list_z=None):
418412
return False
419413
if list_y is not None and not equal_axis_list(his.GetYaxis(), list_y):
420414
return False
421-
if list_z is not None and not equal_axis_list(his.GetZaxis(), list_z):
422-
return False
423-
return True
415+
return not (list_z is not None and not equal_axis_list(his.GetZaxis(), list_z))
424416

425417

426418
def folding(h_input, response_matrix, h_output):
@@ -939,13 +931,13 @@ def plot_latex(latex):
939931
# set canvas margins
940932
if isinstance(margins_c, list) and len(margins_c) > 0:
941933
for setter, value in zip(
942-
[can.SetBottomMargin, can.SetLeftMargin, can.SetTopMargin, can.SetRightMargin], margins_c
934+
[can.SetBottomMargin, can.SetLeftMargin, can.SetTopMargin, can.SetRightMargin], margins_c, strict=False
943935
):
944936
setter(value)
945937
# set logarithmic scale for selected axes
946938
log_y = False
947939
if isinstance(logscale, str) and len(logscale) > 0:
948-
for setter, axis in zip([can.SetLogx, can.SetLogy, can.SetLogz], ["x", "y", "z"]):
940+
for setter, axis in zip([can.SetLogx, can.SetLogy, can.SetLogz], ["x", "y", "z"], strict=False):
949941
if axis in logscale:
950942
setter()
951943
if axis == "y":
@@ -1349,7 +1341,7 @@ def format_value_with_unc(y, e_stat=None, e_syst_plus=None, e_syst_minus=None, n
13491341
if str_e_syst_plus == str_e_syst_minus:
13501342
str_value += f" ± {str_e_syst_plus} (syst.)"
13511343
else:
1352-
str_value += f" +{str_e_syst_plus}{str_e_syst_minus} (syst.)"
1344+
str_value += f" +{str_e_syst_plus}{str_e_syst_minus} (syst.)" # noqa: RUF001
13531345
return str_value
13541346

13551347

0 commit comments

Comments
 (0)