Skip to content

Commit 6787c43

Browse files
committed
more utilities
1 parent d2f48db commit 6787c43

File tree

1 file changed

+59
-11
lines changed

1 file changed

+59
-11
lines changed

src/boostedhh/utils.py

Lines changed: 59 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def list_intersection(list1, list2):
6969

7070

7171
@dataclass
72-
class LoadedSample(ABC):
72+
class LoadedSampleABC(ABC):
7373
"""Abstract base class for loaded samples.
7474
7575
This class defines the interface for accessing variables in a loaded sample.
@@ -94,6 +94,19 @@ def get_var(self, feat: str) -> np.ndarray:
9494
The variable as a numpy array
9595
"""
9696

97+
@abstractmethod
def copy_from_selection(
    self, selection: np.ndarray[bool], do_deepcopy: bool = False
) -> LoadedSampleABC:
    """Copy the events from a selection.

    Args:
        selection: boolean mask
        do_deepcopy: NOTE(review): not described by the original docstring;
            presumably asks implementations to deep-copy the selected data
            rather than share it with this sample -- confirm against the
            concrete subclasses. Defaults to False.

    Returns:
        A new LoadedSampleABC object with the selected events
    """
109+
97110
def apply_selection(self, selection: np.ndarray[bool]):
98111
"""Apply a selection to the events.
99112
@@ -252,12 +265,15 @@ class Cutflow:
252265
cutflow: pd.DataFrame = None
253266

254267
def __post_init__(self):
    """Normalise ``samples`` and derive the sample labels and cutflow table.

    If the first value of ``self.samples`` has a type whose string form
    contains "LoadedSample", every value is replaced by its ``.sample``
    attribute. ``self.sample_labels`` is then built from each sample's
    ``label`` and, when no cutflow table was supplied, an empty DataFrame
    indexed by those labels is created.
    """
    # A default of None keeps an empty ``samples`` dict from raising
    # StopIteration here; it then simply yields no labels and an empty table.
    first = next(iter(self.samples.values()), None)
    if "LoadedSample" in str(type(first)):
        self.samples = {skey: s.sample for skey, s in self.samples.items()}

    self.sample_labels = [s.label for s in self.samples.values()]

    if self.cutflow is None:
        self.cutflow = pd.DataFrame(index=self.sample_labels)
259275

260-
def add_cut(self, events_dict: dict[str, LoadedSample], cut_key: str, weight_key: str):
276+
def add_cut(self, events_dict: dict[str, LoadedSampleABC], cut_key: str, weight_key: str):
261277
"""Add a cut to the cutflow.
262278
263279
Args:
@@ -272,6 +288,9 @@ def add_cut(self, events_dict: dict[str, LoadedSample], cut_key: str, weight_key
272288
def concat(self, other: pd.DataFrame) -> Cutflow:
    """Append the columns of ``other`` to this cutflow table.

    Args:
        other: table of additional cut columns, joined column-wise
            (``axis=1``) onto ``self.cutflow``.

    Returns:
        This object. Returning it (instead of ``None``) keeps the
        ``cutflow = prev_cutflow.concat(cutflow)`` pattern used by callers
        from discarding the cutflow.
    """
    self.cutflow = pd.concat((self.cutflow, other), axis=1)
    return self
274290

291+
def to_csv(self, path: str | Path):
    """Write the cutflow table to ``path`` as CSV.

    Args:
        path: destination, handed unchanged to ``pandas.DataFrame.to_csv``.
    """
    self.cutflow.to_csv(path)
293+
275294

276295
@contextlib.contextmanager
277296
def timer():
@@ -824,7 +843,7 @@ def blindBins(h: Hist, blind_region: list, blind_sample: str | None = None, axis
824843

825844

826845
def singleVarHist(
827-
events_dict: dict[str, LoadedSample],
846+
events_dict: dict[str, LoadedSampleABC],
828847
shape_var: ShapeVar,
829848
weight_key: str = "finalWeight",
830849
selection: dict | None = None,
@@ -929,7 +948,7 @@ def add_selection(
929948
sel: np.ndarray[bool],
930949
selection,
931950
cutflow: dict,
932-
sample: LoadedSample,
951+
sample: LoadedSampleABC,
933952
weight_key: str,
934953
):
935954
"""Adds selection to PackedSelection object and the cutflow"""
@@ -961,7 +980,7 @@ def var_mapping(var):
961980

962981

963982
def _var_selection(
964-
sample: LoadedSample,
983+
sample: LoadedSampleABC,
965984
var: str,
966985
brange: list[float],
967986
jshift: str,
@@ -1006,7 +1025,7 @@ def _var_selection(
10061025

10071026
def make_selection(
10081027
var_cuts: dict[str, list[float]],
1009-
events_dict: dict[str, LoadedSample],
1028+
events_dict: dict[str, LoadedSampleABC],
10101029
weight_key: str = "finalWeight",
10111030
prev_cutflow: Cutflow = None,
10121031
selection: dict[str, np.ndarray] = None,
@@ -1104,14 +1123,12 @@ def make_selection(
11041123
selection[skey] = selection[skey].all(*selection[skey].names)
11051124

11061125
cutflow = pd.DataFrame.from_dict(list(cutflow.values()))
1107-
cutflow.index = [s.label for s in events_dict.values()]
1126+
cutflow.index = [s.sample.label for s in events_dict.values()]
11081127

11091128
if prev_cutflow is not None:
11101129
cutflow = prev_cutflow.concat(cutflow)
11111130
else:
1112-
cutflow = Cutflow(
1113-
samples={skey: sample.sample for skey, sample in events_dict.items()}, cutflow=cutflow
1114-
)
1131+
cutflow = Cutflow(samples=events_dict, cutflow=cutflow)
11151132

11161133
return selection, cutflow
11171134

@@ -1258,3 +1275,34 @@ def combine_hbb_bgs(hists, systs: list[str] = ()):
12581275
h.view()[nsamples + 1 + i] = hbbhist.view()
12591276

12601277
return h
1278+
1279+
1280+
def get_fill_data(
    events: LoadedSampleABC, shape_vars: list[ShapeVar], jshift: str = ""
) -> dict[str, np.ndarray]:
    """Collect the per-variable arrays used to fill a histogram.

    Args:
        events: sample the variables are read from, via ``get_var``.
        shape_vars: variables to look up; each entry's ``var`` is the key of
            the returned dict.
        jshift: when non-empty, each variable name is first passed through
            ``check_get_jec_var`` together with this shift before the lookup.
            Defaults to "" (names are used as-is).

    Returns:
        Dict mapping every ``shape_var.var`` to the result of
        ``events.get_var``. Keys are always the unshifted names, even when
        ``jshift`` changes which variable is actually read.
    """
    # The shift is the same for every variable, so decide once up front.
    nominal = jshift == ""
    return {
        shape_var.var: events.get_var(
            shape_var.var if nominal else check_get_jec_var(shape_var.var, jshift),
        )
        for shape_var in shape_vars
    }
1287+
1288+
1289+
def get_qcdvar_hists(
    sample: LoadedSampleABC, shape_vars: list[ShapeVar], fill_data: dict, wshift: str
):
    """Get histograms for QCD scale and PDF variations

    Args:
        sample: sample whose ``events`` holds a ``f"{wshift}_weights"`` entry
            exposing ``columns``; one histogram category is made per column.
        shape_vars: one histogram axis is added per entry (``shape_var.axis``).
        fill_data: keyword arguments forwarded to ``Hist.fill`` for the
            variable axes (presumably keyed by axis name -- confirm against
            the ``ShapeVar`` axis definitions).
        wshift: prefix of the weights entry, i.e. ``f"{wshift}_weights"``.

    Returns:
        A weight-storage ``Hist`` with a "Sample" string-category axis (one
        category per weight column, named ``str(column)``) followed by the
        axes of ``shape_vars``.
    """
    wkey = f"{wshift}_weights"
    # Look the weights table up once instead of on every loop iteration.
    weights = sample.events[wkey]
    cols = list(weights.columns)
    h = Hist(
        hist.axis.StrCategory([str(i) for i in cols], name="Sample"),
        *[shape_var.axis for shape_var in shape_vars],
        storage="weight",
    )

    for i in cols:
        h.fill(
            Sample=str(i),
            **fill_data,
            weight=weights[i].to_numpy().squeeze(),
        )

    return h

0 commit comments

Comments
 (0)