Skip to content

Commit 293e78c

Browse files
authored
Simplify per ptjet fitting logic (alisw#1010)
* Simplify per ptjet fitting logic * Clean up jet analyzer * Fall back to inclusive ptjet bin for sidesub * Clean up jet analyzer
1 parent 3b06b0c commit 293e78c

File tree

1 file changed

+66
-104
lines changed

1 file changed

+66
-104
lines changed

machine_learning_hep/analysis/analyzer_jets.py

Lines changed: 66 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -62,23 +62,13 @@ def __init__(self, datap, case, typean, period):
6262
super().__init__(datap, case, typean, period)
6363

6464
# output directories
65-
self.d_resultsallpmc = self.cfg(f"mc.results.{period}") if period is not None else self.cfg("mc.resultsallp")
66-
self.d_resultsallpdata = (
67-
self.cfg(f"data.results.{period}") if period is not None else self.cfg("data.resultsallp")
68-
)
65+
suffix = f"results.{period}" if period is not None else "resultsallp"
66+
self.d_resultsallpmc = self.cfg(f"mc.{suffix}")
67+
self.d_resultsallpdata = self.cfg(f"data.{suffix}")
6968

7069
# input directories (processor output)
71-
self.d_resultsallpmc_proc = self.d_resultsallpmc
72-
self.d_resultsallpdata_proc = self.d_resultsallpdata
73-
# use a different processor output
74-
if "data_proc" in datap["analysis"][typean]:
75-
self.d_resultsallpdata_proc = (
76-
self.cfg(f"data_proc.results.{period}") if period is not None else self.cfg("data_proc.resultsallp")
77-
)
78-
if "mc_proc" in datap["analysis"][typean]:
79-
self.d_resultsallpmc_proc = (
80-
self.cfg(f"mc_proc.results.{period}") if period is not None else self.cfg("mc_proc.resultsallp")
81-
)
70+
self.d_resultsallpdata_proc = self.cfg(f"data_proc.{suffix}")
71+
self.d_resultsallpmc_proc = self.cfg(f"mc_proc.{suffix}")
8272

8373
# input files
8474
n_filemass_name = datap["files_names"]["histofilename"]
@@ -93,6 +83,7 @@ def __init__(self, datap, case, typean, period):
9383
self.p_pdfnames = datap["analysis"][self.typean].get("pdf_names")
9484
self.p_param_names = datap["analysis"][self.typean].get("param_names")
9585

86+
# TODO: should come entirely from DB
9687
self.observables = {
9788
"qa": ["zg", "rg", "nsd", "zpar", "dr", "lntheta", "lnkt", "lntheta-lnkt"],
9889
"all": [*self.cfg("observables", {})],
@@ -104,7 +95,6 @@ def __init__(self, datap, case, typean, period):
10495
self.fit_levels = self.cfg("fit_levels", ["mc", "data"])
10596
self.fit_sigma = {}
10697
self.fit_mean = {}
107-
self.fit_func_bkg = {}
10898
self.fit_range = {}
10999
self.hcandeff = {"pr": None, "np": None}
110100
self.hcandeff_gen = {}
@@ -124,6 +114,7 @@ def __init__(self, datap, case, typean, period):
124114
for param, symbol in zip(
125115
("mean", "sigma", "significance", "chi2"),
126116
("#it{#mu}", "#it{#sigma}", "significance", "#it{#chi}^{2}"),
117+
strict=False,
127118
)
128119
}
129120
for level in self.fit_levels
@@ -140,10 +131,8 @@ def __init__(self, datap, case, typean, period):
140131
self.file_out_histo = TFile(self.n_fileresult, "recreate")
141132

142133
self.fitter = RooFitter()
143-
self.roo_ws = {}
144-
self.roo_ws_ptjet = {}
145-
self.roows = {}
146-
self.roows_ptjet = {}
134+
self.roo_ws = {} # ROOT workspaces stored at various levels
135+
self.roows = {} # ROOT workspaces at latest level
147136

148137
# region helpers
149138
def _save_canvas(self, canvas, filename):
@@ -487,17 +476,13 @@ def _fit_mass(self, hist, filename=None):
487476
# pylint: disable=too-many-branches,too-many-statements
488477
def fit(self):
489478
if not self.cfg("hfjet", True):
490-
self.logger.info("Not fitting mass distributions for inclusive jets")
491479
return
492480
self.logger.info("Fitting inclusive mass distributions")
493481
gStyle.SetOptFit(1111)
494482
for level in self.fit_levels:
495483
self.fit_mean[level] = [None] * self.nbins
496484
self.fit_sigma[level] = [None] * self.nbins
497-
self.fit_func_bkg[level] = [None] * self.nbins
498485
self.fit_range[level] = [None] * self.nbins
499-
self.roo_ws[level] = [None] * self.nbins
500-
self.roo_ws_ptjet[level] = [[None] * self.nbins for _ in range(10)]
501486
rfilename = self.n_filemass_mc if "mc" in level else self.n_filemass
502487
fitcfg = None
503488
self.logger.debug("Opening file %s.", rfilename)
@@ -528,23 +513,22 @@ def fit(self):
528513
if h_invmass.GetEntries() < 100: # TODO: reconsider criterion
529514
self.logger.error("Not enough entries to fit %s iptjet %s ipt %d", level, iptjet, ipt)
530515
continue
531-
fit_res, _, func_bkg = self._fit_mass(
516+
fit_res, _, _ = self._fit_mass(
532517
h_invmass, f"fit/h_mass_fitted_{string_range_pthf(range_pthf)}_{level}.png"
533518
)
534519
if fit_res and fit_res.Get() and fit_res.IsValid():
535520
self.fit_mean[level][ipt] = fit_res.Parameter(1)
536521
self.fit_sigma[level][ipt] = fit_res.Parameter(2)
537-
self.fit_func_bkg[level][ipt] = func_bkg
538522
else:
539523
self.logger.error("Fit failed for %s bin %d", level, ipt)
540524
if self.cfg("mass_roofit"):
541525
for entry in self.cfg("mass_roofit", []):
542-
if lvl := entry.get("level"):
543-
if lvl != level:
544-
continue
545-
if ptspec := entry.get("ptrange"):
546-
if ptspec[0] > range_pthf[0] or ptspec[1] < range_pthf[1]:
547-
continue
526+
if (lvl := entry.get("level")) and lvl != level:
527+
continue
528+
if (ptspec := entry.get("ptrange")) and (
529+
ptspec[0] > range_pthf[0] or ptspec[1] < range_pthf[1]
530+
):
531+
continue
548532
fitcfg = entry
549533
break
550534
self.logger.debug("Using fit config for %i: %s", ipt, fitcfg)
@@ -559,7 +543,7 @@ def fit(self):
559543
if h_invmass.GetEntries() < 100: # TODO: reconsider criterion
560544
self.logger.error("Not enough entries to fit %s iptjet %s ipt %d", level, iptjet, ipt)
561545
continue
562-
roows = self.roows.get(ipt) if iptjet is None else self.roows_ptjet.get((iptjet, ipt))
546+
roows = self.roows.get((iptjet, ipt))
563547
if roows is None and level != self.fit_levels[0]:
564548
self.logger.critical(
565549
"missing previous fit result, cannot fit %s iptjet %s ipt %d", level, iptjet, ipt
@@ -596,18 +580,17 @@ def fit(self):
596580
# roo_ws.Print()
597581
# TODO: save snapshot per level
598582
# roo_ws.saveSnapshot(level, None)
599-
if iptjet is not None:
600-
self.logger.debug("Setting roows_ptjet for %s iptjet %s ipt %d", level, iptjet, ipt)
601-
self.roows_ptjet[(iptjet, ipt)] = roo_ws
602-
self.roo_ws_ptjet[level][iptjet][ipt] = roo_ws
603-
else:
604-
self.logger.debug("Setting roows for %s iptjet %s ipt %d", level, iptjet, ipt)
605-
self.roows[ipt] = roo_ws
606-
self.roo_ws[level][ipt] = roo_ws
607-
for jptjet in range(get_nbins(h, 1)):
608-
self.roows_ptjet[(jptjet, ipt)] = roo_ws.Clone()
609-
self.roo_ws_ptjet[level][jptjet][ipt] = roo_ws.Clone()
610-
# TODO: take parameter names from DB
583+
self.logger.info("Setting roows_ptjet for %s iptjet %s ipt %d", level, iptjet, ipt)
584+
self.roows[(iptjet, ipt)] = roo_ws.Clone()
585+
self.roo_ws[(level, iptjet, ipt)] = roo_ws.Clone()
586+
if iptjet is None:
587+
if not fitcfg.get("per_ptjet"):
588+
for jptjet in range(get_nbins(h, 1)):
589+
self.logger.info(
590+
"Overwriting roows_ptjet for %s iptjet %s ipt %d", level, jptjet, ipt
591+
)
592+
self.roows[(jptjet, ipt)] = roo_ws.Clone()
593+
self.roo_ws[(level, jptjet, ipt)] = roo_ws.Clone()
611594
if level in ("data", "mc"):
612595
varname_mean = fitcfg.get("var_mean", self.p_param_names["gauss_mean"])
613596
varname_sigma = fitcfg.get("var_sigma", self.p_param_names["gauss_sigma"])
@@ -626,8 +609,6 @@ def fit(self):
626609
ipt + 1, roo_ws.var(varname_sigma).getError()
627610
)
628611
varname_m = fitcfg.get("var", "m")
629-
if roo_ws.pdf("bkg"):
630-
self.fit_func_bkg[level][ipt] = roo_ws.pdf("bkg").asTF(roo_ws.var(varname_m))
631612
self.fit_range[level][ipt] = (
632613
roo_ws.var(varname_m).getMin("fit"),
633614
roo_ws.var(varname_m).getMax("fit"),
@@ -660,12 +641,12 @@ def _subtract_sideband(self, hist, var, mcordata, ipt):
660641
return None
661642

662643
for entry in self.cfg("sidesub", []):
663-
if level := entry.get("level"):
664-
if level != mcordata:
665-
continue
666-
if ptrange_sel := entry.get("ptrange"):
667-
if ptrange_sel[0] > self.bins_candpt[ipt] or ptrange_sel[1] < self.bins_candpt[ipt + 1]:
668-
continue
644+
if (level := entry.get("level")) and level != mcordata:
645+
continue
646+
if (ptrange_sel := entry.get("ptrange")) and (
647+
ptrange_sel[0] > self.bins_candpt[ipt] or ptrange_sel[1] < self.bins_candpt[ipt + 1]
648+
):
649+
continue
669650
regcfg = entry["regions"]
670651
break
671652
regions = {
@@ -699,7 +680,7 @@ def _subtract_sideband(self, hist, var, mcordata, ipt):
699680

700681
fh = {}
701682
area = {}
702-
var_m = self.roows[ipt].var("m")
683+
var_m = self.roows[(None, ipt)].var("m")
703684
for region in regions:
704685
# project out the mass regions (first axis)
705686
axes = list(range(get_dim(hist)))[1:]
@@ -716,56 +697,37 @@ def _subtract_sideband(self, hist, var, mcordata, ipt):
716697
[fh["sideband_left"], fh["sideband_right"]], f"h_ptjet{label}_sideband_{ipt}_{mcordata}"
717698
)
718699
ensure_sumw2(fh_sideband)
719-
720-
subtract_sidebands = False
721-
if mcordata == "data" and self.cfg("sidesub_per_ptjet"):
722-
self.logger.info("Subtracting sidebands in pt jet bins")
723-
for iptjet in range(get_nbins(fh_subtracted, 0)):
724-
if rws := self.roo_ws_ptjet[mcordata][iptjet][ipt]:
725-
f = rws.pdf("bkg").asTF(self.roo_ws[mcordata][ipt].var("m"))
726-
else:
727-
self.logger.error("Could not retrieve roows for %s-%i-%i", mcordata, iptjet, ipt)
728-
continue
700+
if mcordata == "data":
701+
bins_ptjet = list(range(get_nbins(fh_subtracted, 0))) if self.cfg("sidesub_per_ptjet") else [None]
702+
self.logger.info("Scaling sidebands in ptjet-%s bins: %s using %s", label, bins_ptjet, fh_sideband)
703+
hx = project_hist(fh_sideband, (0,), {}) if get_dim(fh_sideband) > 1 else fh_sideband
704+
for iptjet in bins_ptjet:
705+
if iptjet:
706+
n = hx.GetBinContent(iptjet)
707+
self.logger.info("Need to scale in ptjet %i: %g", iptjet, n)
708+
if n <= 0:
709+
continue
710+
rws = self.roo_ws.get((mcordata, iptjet, ipt))
711+
if not rws:
712+
self.logger.error("Falling back to incl. roows for %s-iptjet%i-ipt%i", mcordata, iptjet, ipt)
713+
rws = self.roo_ws.get((mcordata, None, ipt))
714+
if not rws:
715+
self.logger.critical("Could not retrieve roows for %s-iptjet%i-ipt%i", mcordata, iptjet, ipt)
716+
f = rws.pdf("bkg").asTF(rws.var("m"))
729717
area = {region: f.Integral(*limits[region]) for region in regions}
730-
self.logger.info(
731-
"areas for %s-%s: %g, %g, %g",
732-
mcordata,
733-
ipt,
734-
area["signal"],
735-
area["sideband_left"],
736-
area["sideband_right"],
737-
)
718+
self.logger.info("areas for %s-iptjet%s-ipt%s: %s", mcordata, iptjet, ipt, area)
738719
if (area["sideband_left"] + area["sideband_right"]) > 0.0:
739-
subtract_sidebands = True
740720
areaNormFactor = area["signal"] / (area["sideband_left"] + area["sideband_right"])
741-
# TODO: extend to higher dimensions
742-
for ibin in range(get_nbins(fh_subtracted, 1)):
743-
scale_bin(fh_sideband, areaNormFactor, iptjet + 1, ibin + 1)
744-
else:
745-
for region in regions:
746-
f = self.roo_ws[mcordata][ipt].pdf("bkg").asTF(self.roo_ws[mcordata][ipt].var("m"))
747-
area[region] = f.Integral(*limits[region])
748-
749-
self.logger.info(
750-
"areas for %s-%s: %g, %g, %g",
751-
mcordata,
752-
ipt,
753-
area["signal"],
754-
area["sideband_left"],
755-
area["sideband_right"],
756-
)
757-
758-
if (area["sideband_left"] + area["sideband_right"]) > 0.0:
759-
subtract_sidebands = True
760-
areaNormFactor = area["signal"] / (area["sideband_left"] + area["sideband_right"])
761-
fh_sideband.Scale(areaNormFactor)
762-
721+
# TODO: generalize and extend to higher dimensions
722+
if iptjet is None:
723+
fh_sideband.Scale(areaNormFactor)
724+
else:
725+
for ibin in range(get_nbins(fh_subtracted, 1)):
726+
scale_bin(fh_sideband, areaNormFactor, iptjet + 1, ibin + 1)
727+
fh_subtracted.Add(fh_sideband, -1.0)
763728
self._save_hist(fh_sideband, f"sideband/h_ptjet{label}_sideband_{string_range_pthf(range_pthf)}_{mcordata}.png")
764-
if subtract_sidebands:
765-
fh_subtracted.Add(fh_sideband, -1.0)
766729

767730
self._clip_neg(fh_subtracted)
768-
769731
self._save_hist(
770732
fh_subtracted,
771733
f"sideband/h_ptjet{label}_subtracted_notscaled_{string_range_pthf(range_pthf)}_{mcordata}.png",
@@ -802,17 +764,17 @@ def _subtract_sideband(self, hist, var, mcordata, ipt):
802764
self._save_canvas(c, filename)
803765

804766
# TODO: calculate per ptjet bin
805-
roows = self.roows[ipt]
767+
roows = self.roows[(None, ipt)]
806768
roows.var("mean").setVal(self.fit_mean[mcordata][ipt])
807769
roows.var("sigma_g1").setVal(self.fit_sigma[mcordata][ipt])
808770
var_m.setRange("signal", *limits["signal"])
809771
var_m.setRange("sidel", *limits["sideband_left"])
810772
var_m.setRange("sider", *limits["sideband_right"])
811773
# correct for reflections
812774
if self.cfg("corr_refl") and (mcordata == "data" or not self.cfg("closure.filter_reflections")):
813-
pdf_sig = self.roows[ipt].pdf("sig")
814-
pdf_refl = self.roows[ipt].pdf("refl")
815-
pdf_bkg = self.roows[ipt].pdf("bkg")
775+
pdf_sig = self.roows[(None, ipt)].pdf("sig")
776+
pdf_refl = self.roows[(None, ipt)].pdf("refl")
777+
pdf_bkg = self.roows[(None, ipt)].pdf("bkg")
816778
frac_sig = roows.var("frac").getVal() if mcordata == "data" else 1.0
817779
frac_bkg = 1.0 - frac_sig
818780
fac_sig = frac_sig * (1.0 - roows.var("frac_refl").getVal())
@@ -862,9 +824,9 @@ def _subtract_sideband(self, hist, var, mcordata, ipt):
862824
self.h_reflcorr.SetBinContent(ipt + 1, corr)
863825
fh_subtracted.Scale(corr)
864826

865-
pdf_sig = self.roows[ipt].pdf("sig")
827+
pdf_sig = self.roows[(None, ipt)].pdf("sig")
866828
frac_sig = pdf_sig.createIntegral(var_m, ROOT.RooFit.NormSet(var_m), ROOT.RooFit.Range("signal")).getVal()
867-
if pdf_peak := self.roows[ipt].pdf("peak"):
829+
if pdf_peak := self.roows[(None, ipt)].pdf("peak"):
868830
frac_peak = pdf_peak.createIntegral(var_m, ROOT.RooFit.NormSet(var_m), ROOT.RooFit.Range("signal")).getVal()
869831
self.logger.info(
870832
"correcting %s-%i for fractional signal area: %g (Gaussian: %g)", mcordata, ipt, frac_sig, frac_peak

0 commit comments

Comments
 (0)