Skip to content

Commit 24e9846

Browse files
committed
Fix formatting
1 parent 3039746 commit 24e9846

File tree

2 files changed

+44
-28
lines changed

2 files changed

+44
-28
lines changed

machine_learning_hep/multiprocesser.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,18 @@
1616
main script for doing data processing, machine learning and analysis
1717
"""
1818

19-
from functools import reduce
2019
import os
2120
import tempfile
21+
from functools import reduce
2222
from typing import TypeVar
2323

2424
from machine_learning_hep.io_ml_utils import dump_yaml_from_dict, parse_yaml
2525
from machine_learning_hep.logger import get_logger
2626
from machine_learning_hep.utilities import merge_method, mergerootfiles
27+
2728
from .common import DataType
2829

30+
2931
class MultiProcesser: # pylint: disable=too-many-instance-attributes, too-many-statements, consider-using-f-string, too-many-branches
3032
species = "multiprocesser"
3133
logger = get_logger()
@@ -111,8 +113,12 @@ def __init__(self, case, proc_class, datap, typean, run_param, datatype):
111113
self.lper_evtorig = [os.path.join(direc, self.n_evtorig) for direc in self.dlper_pkl]
112114

113115
dp = self.cfg(f"mlapplication.{self.datatype.value}", {})
114-
self.dlper_reco_modapp = [self.d_prefix_app + p for p in dp["pkl_skimmed_dec"]] if dp else [None] * len(self.p_period)
115-
self.dlper_reco_modappmerged = [self.d_prefix_app + p for p in dp["pkl_skimmed_decmerged"]] if dp else [None] * len(self.p_period)
116+
self.dlper_reco_modapp = (
117+
[self.d_prefix_app + p for p in dp["pkl_skimmed_dec"]] if dp else [None] * len(self.p_period)
118+
)
119+
self.dlper_reco_modappmerged = (
120+
[self.d_prefix_app + p for p in dp["pkl_skimmed_decmerged"]] if dp else [None] * len(self.p_period)
121+
)
116122

117123
dp = self.cfg(f"analysis.{self.typean}.{self.datatype.value}", {})
118124
self.d_results = [self.d_prefix_res + os.path.expandvars(p) for p in dp["results"]]
@@ -121,7 +127,9 @@ def __init__(self, case, proc_class, datap, typean, run_param, datatype):
121127
self.f_evt_mergedallp = os.path.join(self.d_pklevt_mergedallp, self.n_evt)
122128
self.f_evtorig_mergedallp = os.path.join(self.d_pklevt_mergedallp, self.n_evtorig)
123129

124-
self.lper_runlistrigger = self.cfg(f"analysis.{self.typean}.{self.datatype.value}.runselection", [None] * len(self.p_period))
130+
self.lper_runlistrigger = self.cfg(
131+
f"analysis.{self.typean}.{self.datatype.value}.runselection", [None] * len(self.p_period)
132+
)
125133

126134
self.lper_mcreweights = None
127135
if self.datatype == DataType.MC:

machine_learning_hep/processer.py

Lines changed: 32 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -219,18 +219,18 @@ def __init__(
219219
# Potentially mask certain values (e.g. nsigma TOF of -999)
220220
self.p_mask_values = datap["ml"].get("mask_values", None)
221221

222-
self.bins_skimming = np.array(list(zip(self.lpt_anbinmin, self.lpt_anbinmax, strict=False)), "d")
223-
self.bins_analysis = np.array(list(zip(self.lpt_finbinmin, self.lpt_finbinmax, strict=False)), "d")
222+
self.bins_skimming = np.array(list(zip(self.lpt_anbinmin, self.lpt_anbinmax, strict=True)), "d")
223+
self.bins_analysis = np.array(list(zip(self.lpt_finbinmin, self.lpt_finbinmax, strict=True)), "d")
224224
bin_matching = [
225225
[ptrange[0] <= bin[0] and ptrange[1] >= bin[1] for ptrange in self.bins_skimming].index(True)
226226
for bin in self.bins_analysis
227227
]
228228

229229
self.lpt_probcutpre = self.cfg_global(f"mlapplication.probcutpresel.{self.datatype}", [None] * self.p_nptbins)
230-
lpt_probcutfin_tmp = self.cfg_global(f"mlapplication.probcutoptimal", [None] * self.p_nptfinbins)
230+
lpt_probcutfin_tmp = self.cfg_global("mlapplication.probcutoptimal", [None] * self.p_nptfinbins)
231231
self.lpt_probcutfin = [lpt_probcutfin_tmp[bin_matching[ibin]] for ibin in range(self.p_nptfinbins)]
232232

233-
if self.datatype in ('mc', 'data'):
233+
if self.datatype in ("mc", "data"):
234234
for ibin, probcutfin in enumerate(self.lpt_probcutfin):
235235
probcutpre = self.lpt_probcutpre[bin_matching[ibin]]
236236
if self.mltype == "MultiClassification":
@@ -254,7 +254,7 @@ def __init__(
254254
for ipt in range(self.p_nptfinbins):
255255
mlsel_multi = [
256256
f"y_test_prob{self.p_modelname}{label.replace('-', '_')} {comp} {probcut}"
257-
for label, comp, probcut in zip(self.class_labels, comps, self.lpt_probcutfin[ipt], strict=False)
257+
for label, comp, probcut in zip(self.class_labels, comps, self.lpt_probcutfin[ipt], strict=True)
258258
]
259259

260260
self.d_pkl_dec = d_pkl_dec
@@ -291,7 +291,7 @@ def __init__(
291291
)
292292

293293
self.lpt_recodec = None
294-
if self.doml and self.datatype in ('mc', 'data'):
294+
if self.doml and self.datatype in ("mc", "data"):
295295
if self.mltype == "MultiClassification":
296296
self.lpt_recodec = [
297297
self.n_reco.replace(
@@ -320,22 +320,30 @@ def __init__(
320320
for i in range(self.p_nptbins)
321321
]
322322

323-
self.mptfiles_recosk = [
324-
createlist(self.d_pklsk, self.l_path, self.lpt_recosk[ipt]) for ipt in range(self.p_nptbins)
325-
] if self.datatype in ('mc', 'data') else []
326-
self.mptfiles_recoskmldec = [
327-
createlist(self.d_pkl_dec, self.l_path, self.lpt_recodec[ipt]) for ipt in range(self.p_nptbins)
328-
] if self.datatype in ('mc', 'data') else []
329-
self.lpt_recodecmerged = [
330-
os.path.join(self.d_pkl_decmerged, self.lpt_recodec[ipt]) for ipt in range(self.p_nptbins)
331-
] if self.datatype in ('mc', 'data') else []
332-
if self.datatype in ('mc', 'fd'):
323+
self.mptfiles_recosk = (
324+
[createlist(self.d_pklsk, self.l_path, self.lpt_recosk[ipt]) for ipt in range(self.p_nptbins)]
325+
if self.datatype in ("mc", "data")
326+
else []
327+
)
328+
self.mptfiles_recoskmldec = (
329+
[createlist(self.d_pkl_dec, self.l_path, self.lpt_recodec[ipt]) for ipt in range(self.p_nptbins)]
330+
if self.datatype in ("mc", "data")
331+
else []
332+
)
333+
self.lpt_recodecmerged = (
334+
[os.path.join(self.d_pkl_decmerged, self.lpt_recodec[ipt]) for ipt in range(self.p_nptbins)]
335+
if self.datatype in ("mc", "data")
336+
else []
337+
)
338+
if self.datatype in ("mc", "fd"):
333339
self.mptfiles_gensk = [
334340
createlist(self.d_pklsk, self.l_path, self.lpt_gensk[ipt]) for ipt in range(self.p_nptbins)
335341
]
336-
self.lpt_gendecmerged = [
337-
os.path.join(self.d_pkl_decmerged, self.lpt_gensk[ipt]) for ipt in range(self.p_nptbins)
338-
] if self.d_pkl_decmerged else []
342+
self.lpt_gendecmerged = (
343+
[os.path.join(self.d_pkl_decmerged, self.lpt_gensk[ipt]) for ipt in range(self.p_nptbins)]
344+
if self.d_pkl_decmerged
345+
else []
346+
)
339347
self.mptfiles_gensk_sl = (
340348
[createlist(self.d_pklsk, self.l_path, self.lpt_gensk_sl[ipt]) for ipt in range(self.p_nptbins)]
341349
if self.lpt_gensk_sl
@@ -378,10 +386,10 @@ def dfread(rdir, trees, cols, idx_name=None):
378386
trees = [trees]
379387
cols = [cols]
380388
# if all(type(var) is str for var in vars): vars = [vars]
381-
if not all((name in rdir for name in trees)):
389+
if not all(name in rdir for name in trees):
382390
self.logger.critical("Missing trees: %s", trees)
383391
df = None
384-
for tree, col in zip([rdir[name] for name in trees], cols, strict=False):
392+
for tree, col in zip([rdir[name] for name in trees], cols, strict=True):
385393
try:
386394
data = tree.arrays(expressions=col, library="np")
387395
dfnew = pd.DataFrame(columns=col, data=data)
@@ -448,7 +456,7 @@ def dfuse(df_spec):
448456
if dfuse(df_spec):
449457
trees = []
450458
cols = []
451-
for tree, spec in zip(df_spec["trees"].keys(), df_spec["trees"].values(), strict=False):
459+
for tree, spec in zip(df_spec["trees"].keys(), df_spec["trees"].values(), strict=True):
452460
if isinstance(spec, list):
453461
trees.append(tree)
454462
cols.append(spec)
@@ -547,8 +555,8 @@ def dfuse(df_spec):
547555

548556
def skim(self, file_index):
549557
dfreco = read_df(self.l_reco[file_index]) if self.datatype != "fd" else None
550-
dfgen = read_df(self.l_gen[file_index]) if self.datatype in ('mc', 'fd') else None
551-
dfgen_sl = read_df(self.l_gen_sl[file_index]) if self.n_gen_sl and self.datatype in ('mc', 'fd') else None
558+
dfgen = read_df(self.l_gen[file_index]) if self.datatype in ("mc", "fd") else None
559+
dfgen_sl = read_df(self.l_gen_sl[file_index]) if self.n_gen_sl and self.datatype in ("mc", "fd") else None
552560

553561
for ipt in range(self.p_nptbins):
554562
if dfreco is not None:

0 commit comments

Comments
 (0)