1+ from astropy .table import Table , join , vstack , unique
12import numpy as np
23import logging
4+ import os
35from operator import le , lt , eq , ne , ge , gt
46import h5py
7+ import tables
8+ from tables import NaturalNameWarning
59from tqdm import tqdm
10+ import warnings
611
712from .preprocessing import convert_to_float32 , check_valid_rows
813from .io import get_number_of_rows_in_table
1116
1217
# Maps every accepted spelling of a comparison operator (symbol or short
# text name) to the corresponding function from the ``operator`` module.
# Note that a single "=" is accepted as an alias for equality.
OPERATORS = {
    "<": lt,
    "lt": lt,
    "<=": le,
    "le": le,
    "==": eq,
    "eq": eq,
    "=": eq,
    "!=": ne,
    "ne": ne,
    ">": gt,
    "gt": gt,
    ">=": ge,
    "ge": ge,
}
2233
# Translates the short text name of a comparison operator into its symbol.
# Names that are already symbols are not listed here; callers look them up
# with ``text2symbol.get(o, o)`` so unknown keys pass through unchanged.
text2symbol = {
    "lt": "<",
    "le": "<=",
    "eq": "==",
    "ne": "!=",
    "gt": ">",
    "ge": ">=",
}
3142
3243
@@ -36,10 +47,10 @@ def build_query(selection_config):
3647 o = text2symbol .get (o , o )
3748
3849 queries .append (
39- ' {} {} {}' .format (k , o , '"' + v + '"' if isinstance (v , str ) else v )
50+ " {} {} {}" .format (k , o , '"' + v + '"' if isinstance (v , str ) else v )
4051 )
4152
42- query = '(' + ' ) & (' .join (queries ) + ')'
53+ query = "(" + " ) & (" .join (queries ) + ")"
4354 return query
4455
4556
@@ -72,17 +83,14 @@ def predict_disp(df, abs_model, sign_model, log_target=False):
7283 return disp_prediction
7384
7485
def predict_dxdy(df, dxdy_model):
    """
    Predict the two values (dx, dy) for every row of ``df``.

    Parameters:
    -----------
    df:
        Feature table; it is passed through ``convert_to_float32`` and then
        indexed with ``.loc`` / ``.values``, so presumably a pandas
        DataFrame (confirm against the callers).
    dxdy_model:
        Fitted model exposing ``predict``; called only on the valid rows.

    Returns:
    --------
    Array of shape ``(len(df), 2)``. Rows that ``check_valid_rows`` flags
    as invalid are left as NaN.
    """
    df_features = convert_to_float32(df)
    valid = check_valid_rows(df_features)

    # 2, because prediction will return two values: dx and dy
    dxdy_prediction = np.full((len(df_features), 2), np.nan)
    dxdy_prediction[valid] = dxdy_model.predict(df_features.loc[valid].values)

    return dxdy_prediction
8795
8896
def create_mask_h5py(
    infile,
    selection_config,
    n_events,
    key="events",
    start=None,
    end=None,
):
    """
    Creates a boolean mask for a dataframe in a h5 file based on a
    selection config.

    Parameters:
    -----------
    infile: open h5py.File (or any mapping indexable as infile[key][name])
        The columns are read as ``infile[key][name][start:end]``.
    selection_config: list of dict
        Each entry is a single-item dict ``{column: (operator, value)}``.
        ``operator`` must be a key of ``OPERATORS``.
    n_events: int
        Total number of rows; upper bound for ``end``.
    key: str
        Path to the dataframe in the file
    start: int
        First row to consider. If None (or 0), start at the first row.
    end: int
        One past the last row to consider, clipped to ``n_events``.
        If None (or 0), ``n_events`` is used.

    Returns:
    --------
    Boolean mask of length ``end - start``

    Raises:
    -------
    ValueError
        If ``selection_config`` is a plain dict, or if one of its entries
        does not contain exactly one column.
    """
    if isinstance(selection_config, dict):
        raise ValueError(
            "Dictionaries are not supported for the cuts anymore, use a list"
        )

    start = start or 0
    end = min(n_events, end) if end else n_events

    n_selected = end - start
    mask = np.ones(n_selected, dtype=bool)

    for c in selection_config:
        # an empty entry would otherwise fail with an opaque IndexError below
        if len(c) != 1:
            raise ValueError(
                "Expected dict with single entry column: [operator, value]."
            )
        name, (operator, value) = next(iter(c.items()))

        before = np.count_nonzero(mask)
        selection = OPERATORS[operator](infile[key][name][start:end], value)
        mask = np.logical_and(mask, selection)
        after = np.count_nonzero(mask)
        log.debug(
            'Cut "%s %s %s" removed %s events',
            name,
            operator,
            value,
            before - after,
        )

    return mask
164+
165+
def create_mask_table(
    table,
    selection_config,
    n_events,
    start=None,
    end=None,
):
    """
    Creates a boolean mask for a pytables.Table object

    Parameters:
    -----------
    table: pytables.Table
        Table to perform selection on; columns are read with
        ``table.col(name)[start:end]``.
    selection_config: list of dict
        Each entry is a single-item dict ``{column: (operator, value)}``.
        ``operator`` must be a key of ``OPERATORS``.
    n_events: int
        Total number of rows; upper bound for ``end``.
    start: int
        First row to consider. If None (or 0), start at the first row.
    end: int
        One past the last row to consider, clipped to ``n_events``.
        If None (or 0), ``n_events`` is used.

    Returns:
    --------
    Boolean mask of length ``end - start``

    Raises:
    -------
    ValueError
        If ``selection_config`` is a plain dict, or if one of its entries
        does not contain exactly one column.
    KeyError
        If a column named in the selection is not part of ``table``.
    """
    # same contract as create_mask_h5py: only a list of cuts is accepted
    if isinstance(selection_config, dict):
        raise ValueError(
            "Dictionaries are not supported for the cuts anymore, use a list"
        )

    start = start or 0
    end = min(n_events, end) if end else n_events

    n_selected = end - start
    mask = np.ones(n_selected, dtype=bool)

    for c in selection_config:
        # an empty entry would otherwise fail with an opaque IndexError below
        if len(c) != 1:
            raise ValueError(
                "Expected dict with single entry column: [operator, value]."
            )
        name, (operator, value) = next(iter(c.items()))

        if name not in table.colnames:
            raise KeyError(
                f"Cant perform selection based on {name} "
                "Column is missing from parameters table"
            )

        before = np.count_nonzero(mask)
        selection = OPERATORS[operator](table.col(name)[start:end], value)
        mask = np.logical_and(mask, selection)
        after = np.count_nonzero(mask)
        log.debug(
            'Cut "%s %s %s" removed %s events',
            name,
            operator,
            value,
            before - after,
        )

    return mask
132222
@@ -135,20 +225,23 @@ def apply_cuts_h5py_chunked(
135225 input_path ,
136226 output_path ,
137227 selection_config ,
138- key = ' events' ,
228+ key = " events" ,
139229 chunksize = 100000 ,
140230 progress = True ,
141231):
142- '''
232+ """
143233 Apply cuts defined in selection config to input_path and write result to
144234 outputpath. Apply cuts to chunksize events at a time.
145- '''
235+ """
146236
147- n_events = get_number_of_rows_in_table (input_path , key = key , )
237+ n_events = get_number_of_rows_in_table (
238+ input_path ,
239+ key = key ,
240+ )
148241 n_chunks = int (np .ceil (n_events / chunksize ))
149- log .debug (' Using {} chunks of size {}' .format (n_chunks , chunksize ))
242+ log .debug (" Using {} chunks of size {}" .format (n_chunks , chunksize ))
150243
151- with h5py .File (input_path , 'r' ) as infile , h5py .File (output_path , 'w' ) as outfile :
244+ with h5py .File (input_path , "r" ) as infile , h5py .File (output_path , "w" ) as outfile :
152245 group = outfile .create_group (key )
153246
154247 for chunk in tqdm (range (n_chunks ), disable = not progress , total = n_chunks ):
@@ -158,29 +251,133 @@ def apply_cuts_h5py_chunked(
158251 mask = create_mask_h5py (
159252 infile , selection_config , n_events , key = key , start = start , end = end
160253 )
161-
162254 for name , dataset in infile [key ].items ():
163255 if chunk == 0 :
164256 if dataset .ndim == 1 :
165257 group .create_dataset (
166- name , data = dataset [start :end ][mask ], maxshape = (None , )
258+ name , data = dataset [start :end ][mask ], maxshape = (None ,)
167259 )
168260 elif dataset .ndim == 2 :
169261 group .create_dataset (
170- name , data = dataset [start :end , :][mask , :], maxshape = (None , 2 )
262+ name ,
263+ data = dataset [start :end , :][mask , :],
264+ maxshape = (None , 2 ),
171265 )
172266 else :
173- log .warning (' Skipping not 1d or 2d column {}' .format (name ))
267+ log .warning (" Skipping not 1d or 2d column {}" .format (name ))
174268
175269 else :
176270
177271 n_old = group [name ].shape [0 ]
178- n_new = mask . sum ( )
272+ n_new = np . count_nonzero ( mask )
179273 group [name ].resize (n_old + n_new , axis = 0 )
180274
181275 if dataset .ndim == 1 :
182- group [name ][n_old : n_old + n_new ] = dataset [start :end ][mask ]
276+ group [name ][n_old : n_old + n_new ] = dataset [start :end ][mask ]
183277 elif dataset .ndim == 2 :
184- group [name ][n_old :n_old + n_new , :] = dataset [start :end ][mask , :]
278+ group [name ][n_old : n_old + n_new , :] = dataset [start :end ][
279+ mask , :
280+ ]
185281 else :
186- log .warning ('Skipping not 1d or 2d column {}' .format (name ))
282+ log .warning ("Skipping not 1d or 2d column {}" .format (name ))
283+
284+
def apply_cuts_cta_dl1(
    input_path,
    output_path,
    selection_config,
    keep_images=True,
):
    """
    Apply cuts from a selection config to a cta dl1 file and write results
    to output_path.

    The cuts are evaluated on every table below
    ``/dl1/event/telescope/parameters``. All other tables that carry an
    ``event_id`` column are then reduced to the (obs_id, event_id) pairs
    that survived in at least one parameter table. Every remaining node is
    copied unchanged.

    Parameters:
    -----------
    input_path: str, Path
        File to read, opened with ``tables.open_file``.
    output_path: str, Path
        File to write; opened in mode "w", so existing content is replaced.
    selection_config: list of dict
        Cuts as accepted by ``create_mask_table``.
    keep_images: bool
        If False, nodes below ``/dl1/event/telescope/images`` and
        ``/simulation/event/telescope/images`` are not written.

    Returns:
    --------
    (n_rows_before, n_rows_after): total number of rows in the parameter
    tables before and after the cuts.
    """
    parameters_key = "/dl1/event/telescope/parameters"
    id_columns = ["obs_id", "event_id"]

    filters = tables.Filters(
        complevel=5,  # compression medium, tradeoff between speed and compression
        complib="blosc:zstd",  # use modern zstd algorithm
        fletcher32=True,  # add checksums to data chunks
    )
    n_rows_before = 0
    n_rows_after = 0

    with tables.open_file(input_path) as in_, tables.open_file(
        output_path, "w", filters=filters
    ) as out_:
        # perform cuts on the measured parameters
        remaining_showers = []
        for table in in_.root.dl1.event.telescope.parameters:
            mask = create_mask_table(
                table,
                selection_config,
                n_events=len(table),
            )
            n_surviving = np.count_nonzero(mask)
            new_table = out_.create_table(
                parameters_key,
                table.name,
                table.description,
                createparents=True,
                expectedrows=n_surviving,
            )
            # set user attributes
            for name in table.attrs._f_list():
                new_table.attrs[name] = table.attrs[name]

            surviving_events = table.read()[mask]
            new_table.append(surviving_events)
            remaining_showers.append(
                Table(
                    data=surviving_events[id_columns],
                    names=id_columns,
                )
            )
            n_rows_before += len(table)
            n_rows_after += n_surviving

        # an event is kept if it survived in at least one parameter table
        selection_table = unique(vstack(remaining_showers))

        for node in in_.walk_nodes():
            nodepath = node._v_parent._v_pathname
            # skip groups, we create them later
            if isinstance(node, tables.group.Group):
                continue
            # parameter tables were already processed
            if nodepath == parameters_key:
                continue
            if not keep_images and nodepath in (
                "/dl1/event/telescope/images",
                "/simulation/event/telescope/images",
            ):
                continue

            # tables with event_id always contain the obs_id as well
            if isinstance(node, tables.Table) and "event_id" in node.colnames:
                selected = join(
                    selection_table,
                    node.read(),
                    keys=id_columns,
                    join_type="left",
                )
                new_table = out_.create_table(
                    nodepath,
                    node.name,
                    node.description,
                    createparents=True,
                    expectedrows=len(selected),
                )
                # set user attributes
                for name in node.attrs._f_list():
                    new_table.attrs[name] = node.attrs[name]
                # numpy casts structured arrays field-by-position, so bring
                # the joined columns back into the node's own column order
                # before converting to the node dtype
                ordered = selected[list(node.colnames)]
                new_table.append(ordered.as_array().astype(node.dtype))
            else:
                if nodepath not in out_:
                    head, tail = os.path.split(nodepath)
                    out_.create_group(head, tail, createparents=True)
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore", NaturalNameWarning)
                    in_.copy_node(
                        node,
                        newparent=out_.root[nodepath],
                    )

        # set root attributes
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", NaturalNameWarning)
            for name in in_.root._v_attrs._f_list():
                out_.root._v_attrs[name] = in_.root._v_attrs[name]

    return n_rows_before, n_rows_after
0 commit comments