@@ -69,7 +69,7 @@ def list_intersection(list1, list2):
69
69
70
70
71
71
@dataclass
72
- class LoadedSample (ABC ):
72
+ class LoadedSampleABC (ABC ):
73
73
"""Abstract base class for loaded samples.
74
74
75
75
This class defines the interface for accessing variables in a loaded sample.
@@ -94,6 +94,19 @@ def get_var(self, feat: str) -> np.ndarray:
94
94
The variable as a numpy array
95
95
"""
96
96
97
+ @abstractmethod
98
+ def copy_from_selection (
99
+ self , selection : np .ndarray [bool ], do_deepcopy : bool = False
100
+ ) -> LoadedSampleABC :
101
+ """Copy the events from a selection.
102
+
103
+ Args:
104
+ selection: boolean mask
105
+
106
+ Returns:
107
+ A new LoadedSampleABC object with the selected events
108
+ """
109
+
97
110
def apply_selection (self , selection : np .ndarray [bool ]):
98
111
"""Apply a selection to the events.
99
112
@@ -252,12 +265,15 @@ class Cutflow:
252
265
cutflow : pd .DataFrame = None
253
266
254
267
def __post_init__ (self ):
255
- self .sample_labels = [s .label for s in self .samples ]
268
+ if "LoadedSample" in str (type (next (iter (self .samples .values ())))):
269
+ self .samples = {skey : s .sample for skey , s in self .samples .items ()}
270
+
271
+ self .sample_labels = [s .label for s in self .samples .values ()]
256
272
257
273
if self .cutflow is None :
258
274
self .cutflow = pd .DataFrame (index = self .sample_labels )
259
275
260
- def add_cut (self , events_dict : dict [str , LoadedSample ], cut_key : str , weight_key : str ):
276
+ def add_cut (self , events_dict : dict [str , LoadedSampleABC ], cut_key : str , weight_key : str ):
261
277
"""Add a cut to the cutflow.
262
278
263
279
Args:
@@ -272,6 +288,9 @@ def add_cut(self, events_dict: dict[str, LoadedSample], cut_key: str, weight_key
272
288
def concat (self , other : dict ):
273
289
self .cutflow = pd .concat ((self .cutflow , other ), axis = 1 )
274
290
291
+ def to_csv (self , path : Path ):
292
+ self .cutflow .to_csv (path )
293
+
275
294
276
295
@contextlib .contextmanager
277
296
def timer ():
@@ -824,7 +843,7 @@ def blindBins(h: Hist, blind_region: list, blind_sample: str | None = None, axis
824
843
825
844
826
845
def singleVarHist (
827
- events_dict : dict [str , LoadedSample ],
846
+ events_dict : dict [str , LoadedSampleABC ],
828
847
shape_var : ShapeVar ,
829
848
weight_key : str = "finalWeight" ,
830
849
selection : dict | None = None ,
@@ -929,7 +948,7 @@ def add_selection(
929
948
sel : np .ndarray [bool ],
930
949
selection ,
931
950
cutflow : dict ,
932
- sample : LoadedSample ,
951
+ sample : LoadedSampleABC ,
933
952
weight_key : str ,
934
953
):
935
954
"""Adds selection to PackedSelection object and the cutflow"""
@@ -961,7 +980,7 @@ def var_mapping(var):
961
980
962
981
963
982
def _var_selection (
964
- sample : LoadedSample ,
983
+ sample : LoadedSampleABC ,
965
984
var : str ,
966
985
brange : list [float ],
967
986
jshift : str ,
@@ -1006,7 +1025,7 @@ def _var_selection(
1006
1025
1007
1026
def make_selection (
1008
1027
var_cuts : dict [str , list [float ]],
1009
- events_dict : dict [str , LoadedSample ],
1028
+ events_dict : dict [str , LoadedSampleABC ],
1010
1029
weight_key : str = "finalWeight" ,
1011
1030
prev_cutflow : Cutflow = None ,
1012
1031
selection : dict [str , np .ndarray ] = None ,
@@ -1104,14 +1123,12 @@ def make_selection(
1104
1123
selection [skey ] = selection [skey ].all (* selection [skey ].names )
1105
1124
1106
1125
cutflow = pd .DataFrame .from_dict (list (cutflow .values ()))
1107
- cutflow .index = [s .label for s in events_dict .values ()]
1126
+ cutflow .index = [s .sample . label for s in events_dict .values ()]
1108
1127
1109
1128
if prev_cutflow is not None :
1110
1129
cutflow = prev_cutflow .concat (cutflow )
1111
1130
else :
1112
- cutflow = Cutflow (
1113
- samples = {skey : sample .sample for skey , sample in events_dict .items ()}, cutflow = cutflow
1114
- )
1131
+ cutflow = Cutflow (samples = events_dict , cutflow = cutflow )
1115
1132
1116
1133
return selection , cutflow
1117
1134
@@ -1258,3 +1275,34 @@ def combine_hbb_bgs(hists, systs: list[str] = ()):
1258
1275
h .view ()[nsamples + 1 + i ] = hbbhist .view ()
1259
1276
1260
1277
return h
1278
+
1279
+
1280
+ def get_fill_data (events : LoadedSampleABC , shape_vars : list [ShapeVar ], jshift : str = "" ):
1281
+ return {
1282
+ shape_var .var : events .get_var (
1283
+ shape_var .var if jshift == "" else check_get_jec_var (shape_var .var , jshift ),
1284
+ )
1285
+ for shape_var in shape_vars
1286
+ }
1287
+
1288
+
1289
+ def get_qcdvar_hists (
1290
+ sample : LoadedSampleABC , shape_vars : list [ShapeVar ], fill_data : dict , wshift : str
1291
+ ):
1292
+ """Get histograms for QCD scale and PDF variations"""
1293
+ wkey = f"{ wshift } _weights"
1294
+ cols = list (sample .events [wkey ].columns )
1295
+ h = Hist (
1296
+ hist .axis .StrCategory ([str (i ) for i in cols ], name = "Sample" ),
1297
+ * [shape_var .axis for shape_var in shape_vars ],
1298
+ storage = "weight" ,
1299
+ )
1300
+
1301
+ for i in cols :
1302
+ h .fill (
1303
+ Sample = str (i ),
1304
+ ** fill_data ,
1305
+ weight = sample .events [wkey ][i ].to_numpy ().squeeze (),
1306
+ )
1307
+
1308
+ return h
0 commit comments