Skip to content

Commit 5c86907

Browse files
authored
Merge pull request #142 from fact-project/cta_dl1
Cta dl1 support
2 parents d9cfd50 + 8816c68 commit 5c86907

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

50 files changed

+3359
-2234
lines changed

aict_tools/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '0.26.0'
1+
__version__ = "0.26.0"

aict_tools/apply.py

Lines changed: 241 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
1+
from astropy.table import Table, join, vstack, unique
12
import numpy as np
23
import logging
4+
import os
35
from operator import le, lt, eq, ne, ge, gt
46
import h5py
7+
import tables
8+
from tables import NaturalNameWarning
59
from tqdm import tqdm
10+
import warnings
611

712
from .preprocessing import convert_to_float32, check_valid_rows
813
from .io import get_number_of_rows_in_table
@@ -11,22 +16,28 @@
1116

1217

1318
# Lookup from the operator spelling used in a selection config to the
# comparison function of the `operator` module. Each comparison can be
# addressed by its symbol or by its textual alias.
OPERATORS = {
    # less than / less than or equal
    "<": lt, "lt": lt,
    "<=": le, "le": le,
    # equality; a single "=" is accepted as well
    "==": eq, "eq": eq, "=": eq,
    # inequality
    "!=": ne, "ne": ne,
    # greater than / greater than or equal
    ">": gt, "gt": gt,
    ">=": ge, "ge": ge,
}
2233

2334
# Textual operator names and the comparison symbol each one stands for.
# Operators that are not listed here are expected to be symbols already.
text2symbol = dict(
    lt="<",
    le="<=",
    eq="==",
    ne="!=",
    gt=">",
    ge=">=",
)
3142

3243

@@ -36,10 +47,10 @@ def build_query(selection_config):
3647
o = text2symbol.get(o, o)
3748

3849
queries.append(
39-
'{} {} {}'.format(k, o, '"' + v + '"' if isinstance(v, str) else v)
50+
"{} {} {}".format(k, o, '"' + v + '"' if isinstance(v, str) else v)
4051
)
4152

42-
query = '(' + ') & ('.join(queries) + ')'
53+
query = "(" + ") & (".join(queries) + ")"
4354
return query
4455

4556

@@ -72,17 +83,14 @@ def predict_disp(df, abs_model, sign_model, log_target=False):
7283
return disp_prediction
7384

7485

75-
def predict_dxdy(df, dxdy_model, log_target=False):
86+
def predict_dxdy(df, dxdy_model):
    """
    Predict the (dx, dy) pair for every row of df.

    Parameters:
    -----------
    df:
        Feature table; it is passed through convert_to_float32 and must
        support ``.loc`` / ``.values`` afterwards
        (presumably a pandas.DataFrame -- confirm against callers)
    dxdy_model:
        Fitted model whose ``predict`` returns two values per row

    Returns:
    --------
    Array of shape (len(df), 2). Rows rejected by check_valid_rows
    contain NaN in both columns.
    """
    features = convert_to_float32(df)
    usable_rows = check_valid_rows(features)

    # two columns, one each for dx and dy; rows that are never predicted stay NaN
    prediction = np.full((len(features), 2), np.nan)
    prediction[usable_rows] = dxdy_model.predict(features.loc[usable_rows].values)

    return prediction
8795

8896

@@ -100,33 +108,115 @@ def create_mask_h5py(
100108
infile,
101109
selection_config,
102110
n_events,
103-
key='events',
111+
key="events",
104112
start=None,
105113
end=None,
106114
):
115+
"""
116+
Creates a boolean mask for a dataframe in a h5 file based on a
117+
selection config.
118+
119+
Parameters:
120+
-----------
121+
infile: str, Path
122+
selection_config: dict
123+
Dictionary with column names as keys and (operator, value) tuples as value
124+
n_events: int
125+
Number of events to select.
126+
key: str
127+
Path to the dataframe in the file
128+
start: int
129+
If None, select the first row
130+
end: int
131+
If None, select the last row
132+
133+
Returns:
134+
--------
135+
Boolean mask with len=n_events or len(df)
136+
"""
107137
start = start or 0
108138
end = min(n_events, end) if end else n_events
109139

110140
n_selected = end - start
111141
mask = np.ones(n_selected, dtype=bool)
112142

113-
# legacy support for dict of column_name -> [op, val]
114143
if isinstance(selection_config, dict):
115-
selection_config = [{k: v} for k, v in selection_config.items()]
144+
raise ValueError("Dictionaries are not supported for the cuts anymore, use a list")
116145

117146
for c in selection_config:
118147
if len(c) > 1:
119-
raise ValueError('Expected dict with single entry column: [operator, value].')
148+
raise ValueError(
149+
"Expected dict with single entry column: [operator, value]."
150+
)
120151
name, (operator, value) = list(c.items())[0]
121152

122-
before = mask.sum()
123-
mask = np.logical_and(
124-
mask, OPERATORS[operator](infile[key][name][start:end], value)
153+
before = np.count_nonzero(mask)
154+
selection = OPERATORS[operator](infile[key][name][start:end], value)
155+
mask = np.logical_and(mask, selection)
156+
after = np.count_nonzero(mask)
157+
log.debug(
158+
'Cut "{} {} {}" removed {} events'.format(
159+
name, operator, value, before - after
160+
)
161+
)
162+
163+
return mask
164+
165+
166+
def create_mask_table(
    table,
    selection_config,
    n_events,
    start=None,
    end=None,
):
    """
    Creates a boolean mask for a pytables.Table object

    Parameters:
    -----------
    table: pytables.Table
        Table to perform selection on
    selection_config: list of dict
        One single-entry dictionary per cut, mapping a column name to an
        (operator, value) pair. The operator has to be a key of OPERATORS.
    n_events: int
        Number of rows available; upper bound for the selected range.
    start: int
        First row of the selected range. If None, start at the first row
    end: int
        End (exclusive) of the selected range, clipped to n_events.
        If None, use n_events

    Returns:
    --------
    Boolean mask with len end - start. A row is True if it passes all cuts.

    Raises:
    -------
    ValueError
        If selection_config is a dictionary or if a cut does not have
        exactly one entry
    KeyError
        If a cut refers to a column that is missing from the table
    """
    # same behaviour as create_mask_h5py: the legacy dict format is rejected
    # explicitly instead of failing with an unrelated error further down
    if isinstance(selection_config, dict):
        raise ValueError("Dictionaries are not supported for the cuts anymore, use a list")

    start = start or 0
    end = min(n_events, end) if end else n_events

    n_selected = end - start
    mask = np.ones(n_selected, dtype=bool)

    for cut in selection_config:
        # an empty cut is just as malformed as one with several entries
        if len(cut) != 1:
            raise ValueError(
                "Expected dict with single entry column: [operator, value]."
            )
        name, (operator, value) = list(cut.items())[0]

        if name not in table.colnames:
            raise KeyError(
                f"Cant perform selection based on {name} "
                "Column is missing from parameters table"
            )

        before = np.count_nonzero(mask)
        selection = OPERATORS[operator](table.col(name)[start:end], value)
        mask = np.logical_and(mask, selection)
        after = np.count_nonzero(mask)
        log.debug(
            'Cut "%s %s %s" removed %d events', name, operator, value, before - after
        )

    return mask
132222

@@ -135,20 +225,23 @@ def apply_cuts_h5py_chunked(
135225
input_path,
136226
output_path,
137227
selection_config,
138-
key='events',
228+
key="events",
139229
chunksize=100000,
140230
progress=True,
141231
):
142-
'''
232+
"""
143233
Apply cuts defined in selection config to input_path and write result to
144234
outputpath. Apply cuts to chunksize events at a time.
145-
'''
235+
"""
146236

147-
n_events = get_number_of_rows_in_table(input_path, key=key, )
237+
n_events = get_number_of_rows_in_table(
238+
input_path,
239+
key=key,
240+
)
148241
n_chunks = int(np.ceil(n_events / chunksize))
149-
log.debug('Using {} chunks of size {}'.format(n_chunks, chunksize))
242+
log.debug("Using {} chunks of size {}".format(n_chunks, chunksize))
150243

151-
with h5py.File(input_path, 'r') as infile, h5py.File(output_path, 'w') as outfile:
244+
with h5py.File(input_path, "r") as infile, h5py.File(output_path, "w") as outfile:
152245
group = outfile.create_group(key)
153246

154247
for chunk in tqdm(range(n_chunks), disable=not progress, total=n_chunks):
@@ -158,29 +251,133 @@ def apply_cuts_h5py_chunked(
158251
mask = create_mask_h5py(
159252
infile, selection_config, n_events, key=key, start=start, end=end
160253
)
161-
162254
for name, dataset in infile[key].items():
163255
if chunk == 0:
164256
if dataset.ndim == 1:
165257
group.create_dataset(
166-
name, data=dataset[start:end][mask], maxshape=(None, )
258+
name, data=dataset[start:end][mask], maxshape=(None,)
167259
)
168260
elif dataset.ndim == 2:
169261
group.create_dataset(
170-
name, data=dataset[start:end, :][mask, :], maxshape=(None, 2)
262+
name,
263+
data=dataset[start:end, :][mask, :],
264+
maxshape=(None, 2),
171265
)
172266
else:
173-
log.warning('Skipping not 1d or 2d column {}'.format(name))
267+
log.warning("Skipping not 1d or 2d column {}".format(name))
174268

175269
else:
176270

177271
n_old = group[name].shape[0]
178-
n_new = mask.sum()
272+
n_new = np.count_nonzero(mask)
179273
group[name].resize(n_old + n_new, axis=0)
180274

181275
if dataset.ndim == 1:
182-
group[name][n_old:n_old + n_new] = dataset[start:end][mask]
276+
group[name][n_old : n_old + n_new] = dataset[start:end][mask]
183277
elif dataset.ndim == 2:
184-
group[name][n_old:n_old + n_new, :] = dataset[start:end][mask, :]
278+
group[name][n_old : n_old + n_new, :] = dataset[start:end][
279+
mask, :
280+
]
185281
else:
186-
log.warning('Skipping not 1d or 2d column {}'.format(name))
282+
log.warning("Skipping not 1d or 2d column {}".format(name))
283+
284+
285+
def _copy_user_attrs(source, target):
    """Copy all attributes listed by ``source.attrs._f_list()`` to target."""
    for name in source.attrs._f_list():
        target.attrs[name] = source.attrs[name]


def apply_cuts_cta_dl1(
    input_path,
    output_path,
    selection_config,
    keep_images=True,
):
    """
    Apply cuts from a selection config to a cta dl1 file and write results
    to output_path.

    The cuts are evaluated on the tables below
    /dl1/event/telescope/parameters. Every other table that has an
    event_id column is reduced to the (obs_id, event_id) pairs that
    survive in at least one parameter table. All remaining nodes are
    copied unchanged.

    Parameters:
    -----------
    input_path: str, Path
        File to read from
    output_path: str, Path
        File to create. It is opened in "w" mode.
    selection_config: list of dict
        Cuts in the format expected by create_mask_table
    keep_images: bool
        If False, nodes below /dl1/event/telescope/images and
        /simulation/event/telescope/images are not written to the output

    Returns:
    --------
    (n_rows_before, n_rows_after): Number of rows summed over all parameter
    tables before and after applying the cuts
    """
    parameters_path = "/dl1/event/telescope/parameters"
    image_paths = (
        "/dl1/event/telescope/images",
        "/simulation/event/telescope/images",
    )
    filters = tables.Filters(
        complevel=5,  # compression medium, tradeoff between speed and compression
        complib="blosc:zstd",  # use modern zstd algorithm
        fletcher32=True,  # add checksums to data chunks
    )
    n_rows_before = 0
    n_rows_after = 0

    with tables.open_file(input_path) as in_, tables.open_file(
        output_path, "w", filters=filters
    ) as out_:
        # perform cuts on the measured parameters
        remaining_showers = []
        for table in in_.root.dl1.event.telescope.parameters:
            mask = create_mask_table(
                table,
                selection_config,
                n_events=len(table),
            )
            n_selected = np.count_nonzero(mask)
            new_table = out_.create_table(
                parameters_path,
                table.name,
                table.description,
                createparents=True,
                expectedrows=n_selected,
            )
            _copy_user_attrs(table, new_table)

            # apply the mask once and reuse the result for the output table
            # and for the list of surviving (obs_id, event_id) pairs
            surviving_events = table.read()[mask]
            new_table.append(surviving_events)
            remaining_showers.append(
                Table(
                    data=surviving_events[["obs_id", "event_id"]],
                    names=["obs_id", "event_id"],
                )
            )
            n_rows_before += len(table)
            n_rows_after += n_selected

        # an event survives if it is left in at least one parameter table
        selection_table = unique(vstack(remaining_showers))

        for node in in_.walk_nodes():
            nodepath = node._v_parent._v_pathname
            # skip groups, we create them later
            if isinstance(node, tables.group.Group):
                continue
            # parameter tables were already processed
            if nodepath == parameters_path:
                continue
            if not keep_images and nodepath in image_paths:
                continue

            # tables with event_id always contain the obs_id as well
            if isinstance(node, tables.Table) and "event_id" in node.colnames:
                # NOTE(review): a left join yields one row per selected event
                # even if the event is missing from this table -- confirm
                # that every event table contains all selected events
                selected = join(
                    selection_table,
                    node.read(),
                    keys=["obs_id", "event_id"],
                    join_type="left",
                )
                new_table = out_.create_table(
                    nodepath,
                    node.name,
                    node.description,
                    createparents=True,
                    expectedrows=len(selected),
                )
                _copy_user_attrs(node, new_table)
                new_table.append(selected.as_array().astype(node.dtype))
            else:
                if nodepath not in out_:
                    head, tail = os.path.split(nodepath)
                    out_.create_group(head, tail, createparents=True)
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore", NaturalNameWarning)
                    in_.copy_node(
                        node,
                        newparent=out_.root[nodepath],
                    )

        # set root attributes
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", NaturalNameWarning)
            for name in in_.root._v_attrs._f_list():
                out_.root._v_attrs[name] = in_.root._v_attrs[name]

    return n_rows_before, n_rows_after

0 commit comments

Comments
 (0)