-
Notifications
You must be signed in to change notification settings - Fork 13
Expand file tree
/
Copy pathfilters.py
More file actions
523 lines (433 loc) · 17.1 KB
/
filters.py
File metadata and controls
523 lines (433 loc) · 17.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
# -*- coding: utf-8 -*-
"""
Filters for NNPDF fits
"""
import logging
import re
from collections.abc import Mapping
from importlib.resources import read_text
import numpy as np
from NNPDF import CommonData, RandomGenerator
from reportengine.checks import make_argcheck, check, check_positive, make_check
from reportengine.compat import yaml
import validphys.cuts
# Logger shared by every function in this module.
log = logging.getLogger(name=__name__)
class RuleProcessingError(Exception):
    """Raised when a filter-rule specification cannot be processed."""


class BadPerturbativeOrder(ValueError):
    """Raised when a perturbative-order string is not recognised."""


class MissingRuleAttribute(RuleProcessingError, AttributeError):
    """Raised when a rule specification lacks a required attribute."""


class FatalRuleError(Exception):
    """Raised when applying a rule to a data point fails at runtime."""
def _load_cuts_yaml(filename):
    """Read ``filename`` from the ``validphys.cuts`` package and parse it with ``yaml.safe_load``."""
    raw_text = read_text(validphys.cuts, filename)
    return yaml.safe_load(raw_text)


def default_filter_settings_input():
    """Return the default, hard-coded filter settings.

    They are parsed from ``defaults.yaml`` shipped in the ``validphys.cuts``
    module.
    """
    return _load_cuts_yaml("defaults.yaml")


def default_filter_rules_input():
    """Return the default filter rules.

    They are parsed from ``filters.yaml`` shipped in the ``validphys.cuts``
    module. Callers such as the ``get_cuts_for_dataset`` example iterate the
    result and build one ``Rule`` per entry, so it is a sequence of rule
    mappings.
    """
    return _load_cuts_yaml("filters.yaml")
@make_argcheck
def check_rngalgo(rngalgo: int):
    """Validate that ``rngalgo`` lies in the accepted range 0-16 (inclusive)."""
    in_range = 0 <= rngalgo and rngalgo < 17
    check(in_range, "Invalid rngalgo. Must be int between [0, 16].")
def check_nonnegative(var: str):
    """Build a reportengine check that the namespace value ``var`` is >= 0.

    Zero is accepted; only strictly negative values fail the check.
    """

    def run_check(ns, **kwargs):
        val = ns[var]
        check(val >= 0, f"'{var}' must be positive or equal zero, but it is {val!r}.")

    return make_check(run_check)
def make_dataset_dir(path):
    """Ensure the dataset output directory ``path`` exists.

    ``path`` is a ``pathlib.Path``-like object. An already existing directory
    is reused, and a warning notes that its contents will be overwritten.
    """
    if not path.exists():
        path.mkdir(exist_ok=True)
        return
    log.warning(f"Dataset output folder exists: {path} Overwriting contents")
def export_mask(path, mask):
    """Write the cut mask ``mask`` to ``path``, one integer index per line."""
    np.savetxt(fname=path, X=mask, fmt="%d")
@check_rngalgo
@check_nonnegative('filterseed')
@check_nonnegative('seed')
def prepare_nnpdf_rng(filterseed: int, rngalgo: int, seed: int):
    """Initialise the internal NNPDF random number generator.

    The generator algorithm is chosen by ``rngalgo`` (an integer between 0 and
    16). It is initialised with ``seed`` and then re-seeded with
    ``filterseed``. The RNG can subsequently be used, e.g., to generate
    pseudodata.
    """
    log.info("Initialising RNG")
    RandomGenerator.InitRNG(rngalgo, seed)
    generator = RandomGenerator.GetRNG()
    generator.SetSeed(filterseed)
@check_positive('errorsize')
def filter_closure_data(filter_path, data, t0pdfset, fakenoise, errorsize, prepare_nnpdf_rng):
    """Filter closure-test data.

    Besides cutting data points, the data are generated from the underlying
    ``t0pdfset``. When ``fakenoise`` is ``True`` a shift is applied, emulating
    experimental central values scattered away from the underlying law.
    Returns the ``(total points, points passing cuts)`` pair.
    """
    # ``prepare_nnpdf_rng`` is not read here. Presumably it is requested so
    # that the RNG is initialised before any closure data are generated.
    log.info('Filtering closure-test data.')
    return _filter_closure_data(
        filter_path,
        data,
        fakepdfset=t0pdfset,
        fakenoise=fakenoise,
        errorsize=errorsize,
    )
@check_positive("errorsize")
def filter_closure_data_by_experiment(
    filter_path, experiments_data, t0pdfset, fakenoise, errorsize, prepare_nnpdf_rng,
):
    """
    Like :py:func:`filter_closure_data`, but filters the data experiment by
    experiment.

    The function simply loops over ``experiments_data`` in order and returns
    one ``(total points, points passing cuts)`` pair per experiment.
    ``reportengine.collect`` is deliberately not used: it can permute the order
    in which closure data are generated, which would make the pseudodata
    irreproducible.
    """
    per_experiment = []
    for experiment in experiments_data:
        per_experiment.append(
            _filter_closure_data(filter_path, experiment, t0pdfset, fakenoise, errorsize)
        )
    return per_experiment
def filter_real_data(filter_path, data):
    """Filter real experimental data, cutting any point that fails the filter rules.

    Returns the ``(total points, points passing cuts)`` pair.
    """
    log.info('Filtering real data.')
    counts = _filter_real_data(filter_path, data)
    return counts
def filter(filter_data):
    """Log a summary of the filters applied to all datasets.

    ``filter_data`` is a single ``(total, passed)`` pair or a collection of
    such pairs. ``np.atleast_2d`` makes both cases sum column-wise.
    """
    column_totals = np.atleast_2d(filter_data).sum(axis=0)
    total_data, total_cut_data = column_totals
    log.info(f'Summary: {total_cut_data}/{total_data} datapoints passed kinematic cuts.')
def _write_ds_cut_data(path, dataset):
    """Prepare the output directory for ``dataset`` and write its cut mask.

    Creates ``path`` (see ``make_dataset_dir``) and logs how many points of
    ``dataset`` survive the kinematic cuts. When the cuts are non-trivial, it
    writes the mask to ``path / FKMASK_<name>.dat``.

    Returns
    -------
    (int, int)
        Total number of points in the commondata and number of points that
        pass the cuts.
    """
    make_dataset_dir(path)
    all_dsndata = dataset.commondata.ndata
    datamask = dataset.cuts.load()
    if datamask is None:
        # ``None`` means there is no mask: every point is kept.
        filtered_dsndata = all_dsndata
        log.info(f"All {all_dsndata} points in {dataset.name} passed kinematic cuts.")
    else:
        filtered_dsndata = len(datamask)
        log.info(f"{filtered_dsndata}/{all_dsndata} datapoints "
                 f"in {dataset.name} passed kinematic cuts.")
        # save to disk
        export_mask(path / f'FKMASK_{dataset.name}.dat', datamask)
    return all_dsndata, filtered_dsndata
def _filter_real_data(filter_path, data):
    """Write cut masks and exported real data for every dataset in ``data``.

    Each dataset goes to ``filter_path / <dataset name>``. Returns the summed
    ``(total points, points passing cuts)`` over all datasets.
    """
    n_total = n_passed = 0
    for ds in data.datasets:
        ds_path = filter_path / ds.name
        full, kept = _write_ds_cut_data(ds_path, ds)
        n_total, n_passed = n_total + full, n_passed + kept
        ds.load().Export(str(ds_path))
    return n_total, n_passed
def _filter_closure_data(filter_path, data, fakepdfset, fakenoise, errorsize):
    """Write cut masks and closure-test pseudodata for every dataset in ``data``.

    The data are regenerated from ``fakepdfset``, with a level-1 shift when
    ``fakenoise`` is set. When ``errorsize`` differs from 1.0, the errors are
    rescaled by that factor before export. Returns the summed
    ``(total points, points passing cuts)``.
    """
    n_total = 0
    n_passed = 0
    fake_pdf = fakepdfset.load()
    # Call the undecorated loader so the result is not cached; it is modified
    # in place by ``MakeClosure`` below.
    closure_data = data.load.__wrapped__(data)
    # generate level 1 shift if fakenoise
    closure_data.MakeClosure(fake_pdf.as_libNNPDF(), fakenoise)
    for index, ds in enumerate(data.datasets):
        ds_path = filter_path / ds.name
        full, kept = _write_ds_cut_data(ds_path, ds)
        n_total += full
        n_passed += kept
        closure_ds = closure_data.GetSet(index)
        if errorsize != 1.0:
            closure_ds.RescaleErrors(errorsize)
        closure_ds.Export(str(ds_path))
    return n_total, n_passed
def check_t0pdfset(t0pdfset):
    """Check that the t0 PDF set can be loaded, logging success."""
    _ = t0pdfset.load()
    log.info(f'{t0pdfset} T0 checked.')
def check_positivity(posdatasets):
    """Verify that every positivity dataset can be loaded before the fit."""
    log.info('Verifying positivity tables:')
    for pos_ds in posdatasets:
        pos_ds.load()
        log.info(f'{pos_ds.name} checked.')
def check_integrability(integdatasets):
    """Verify that every integrability dataset can be loaded before the fit."""
    log.info('Verifying integrability tables:')
    for integ_ds in integdatasets:
        integ_ds.load()
        log.info(f'{integ_ds.name} checked.')
class PerturbativeOrder:
    """Parsed perturbative-order specification used by the ``Rule`` filter.

    Parameters
    ----------
    string: str
        A string such as ``NNLO`` or, equivalently, ``N2LO`` (case-insensitive).
        It may be followed by one of ``!``, ``+`` or ``-``, or by nothing.

        The suffix restricts which perturbative orders a rule applies to.
        The four cases are:

        NNLO+ only execute the following rule if the pto is 2 or greater
        NNLO- only execute the following rule if the pto is strictly less than 2
        NNLO! only execute the following rule if the pto is strictly not 2
        NNLO  only execute the following rule if the pto is exactly 2

        Any unrecognized string raises a ``BadPerturbativeOrder`` exception.

    Example
    -------
    >>> from validphys.filters import PerturbativeOrder
    >>> pto = PerturbativeOrder("NNLO+")
    >>> pto.numeric_pto
    2
    >>> 1 in pto
    False
    >>> 2 in pto
    True
    >>> 3 in pto
    True
    """

    # Either "N<digits>LO" or a run of N's followed by "LO", then an optional suffix.
    _PTO_PATTERN = re.compile(
        r"(N(?P<nnumber>\d+)|(?P<multiplens>N*))LO(?P<operator>[\+\-\!])?"
    )

    # Membership predicate per suffix; ``None`` (no suffix) means exact match.
    _COMPARISONS = {
        "!": lambda i, pto: i != pto,
        "+": lambda i, pto: i >= pto,
        "-": lambda i, pto: i < pto,
        None: lambda i, pto: i == pto,
    }

    def __init__(self, string):
        self.string = string.upper()
        self.parse()

    def parse(self):
        """Set ``numeric_pto`` and ``operator`` from ``self.string``.

        For example, ``NNNLO`` and ``N3LO`` both give ``numeric_pto == 3``.

        Raises
        ------
        BadPerturbativeOrder
            If ``self.string`` does not match the expected syntax.
        """
        match = self._PTO_PATTERN.fullmatch(self.string)
        if match is None:
            raise BadPerturbativeOrder(
                f"String {self.string!r} does not represent a valid perturbative order specfication."
            )
        # Exactly one of the two alternatives matched.
        explicit_order = match.group("nnumber")
        if explicit_order is not None:
            self.numeric_pto = int(explicit_order)
        else:
            self.numeric_pto = len(match.group("multiplens"))
        self.operator = match.group("operator")

    def __contains__(self, i):
        comparison = self._COMPARISONS.get(self.operator, self._COMPARISONS[None])
        return comparison(i, self.numeric_pto)
class Rule:
    """Rule object to be used to generate cuts mask.

    A rule object is created for each rule in ./cuts/filters.yaml

    Parameters
    ----------
    initial_data: dict
        A dictionary containing all the information regarding the rule.
        This contains the name of the dataset the rule applies to
        and/or the process type the rule applies to. Additionally, the
        rule itself is defined, alongside the reason the rule is used.
        Finally, the user can optionally define their own custom local
        variables.

        By default these are defined in cuts/filters.yaml
    defaults: dict
        A dictionary containing default values to be used globally in
        all rules.

        By default these are defined in cuts/defaults.yaml
    theory_parameters: dict
        Dict containing pairs of (theory_parameter, value)
    loader: validphys.loader.Loader, optional
        A loader instance used to retrieve the datasets.

    Raises
    ------
    MissingRuleAttribute
        If no ``rule`` is given, or neither ``dataset`` nor ``process_type``
        is defined.
    RuleProcessingError
        If the dataset cannot be found, ``local_variables`` is not a mapping,
        ``PTO`` is not a valid perturbative-order string, or the rule or a
        local variable fails to compile or uses an unknown top-level name.

    Notes
    -----
    Expressions (the rule and its local variables) may use the numpy
    functions in ``numpy_functions``, the keys of ``defaults``, the kinematic
    variable names, ``idat``, ``central_value`` and any earlier local
    variable. Calling a rule returns ``None`` when it does not apply to the
    point; otherwise it returns the value of the rule expression, which is
    truthy when the point passes.
    """

    numpy_functions = {"sqrt": np.sqrt, "log": np.log, "fabs": np.fabs}

    def __init__(
        self,
        initial_data: dict,
        *,
        defaults: dict,
        theory_parameters: dict,
        loader=None,
    ):
        self.dataset = None
        self.process_type = None
        self._local_variables_code = {}
        # Every key of the rule specification (e.g. rule, dataset,
        # process_type, PTO, local_variables) becomes an attribute.
        for key in initial_data:
            setattr(self, key, initial_data[key])

        if not hasattr(self, "rule"):
            raise MissingRuleAttribute("No rule defined.")

        if self.dataset is None and self.process_type is None:
            raise MissingRuleAttribute(
                "Please define either a process type or dataset."
            )

        if self.process_type is None:
            # Local import; presumably avoids an import cycle with
            # validphys.loader (TODO confirm).
            from validphys.loader import Loader, LoaderError

            if loader is None:
                loader = Loader()
            try:
                cd = loader.check_commondata(self.dataset)
            except LoaderError as e:
                raise RuleProcessingError(
                    f"Could not find dataset {self.dataset}"
                ) from e
            self.variables = self._kinematic_labels(cd.process_type)
        else:
            self.variables = self._kinematic_labels(self.process_type)

        if hasattr(self, "local_variables"):
            if not isinstance(self.local_variables, Mapping):
                raise RuleProcessingError(
                    f"Expecting local_variables to be a Mapping, not {type(self.local_variables)}."
                )
        else:
            self.local_variables = {}

        if hasattr(self, "PTO"):
            if not isinstance(self.PTO, str):
                raise RuleProcessingError(
                    f"Expecting PTO to be a string, not {type(self.PTO)}."
                )
            try:
                self.PTO = PerturbativeOrder(self.PTO)
            except BadPerturbativeOrder as e:
                raise RuleProcessingError(e) from e

        self.rule_string = self.rule
        self.defaults = defaults
        self.theory_params = theory_parameters

        # Names an expression may reference. Unknown top-level names are
        # rejected here, at construction, rather than when the rule is applied.
        ns = {
            *self.numpy_functions,
            *self.defaults,
            *self.variables,
            "idat",
            "central_value",
        }
        for k, v in self.local_variables.items():
            try:
                self._local_variables_code[k] = lcode = compile(
                    str(v), f"local variable {k}", "eval"
                )
            except Exception as e:
                raise RuleProcessingError(
                    f"Could not process local variable {k!r} ({v!r}): {e}"
                ) from e
            for name in lcode.co_names:
                if name not in ns:
                    raise RuleProcessingError(
                        f"Could not process local variable {k!r}: Unknown name {name!r}"
                    )
            # Later local variables may refer to this one.
            ns.add(k)

        try:
            self.rule = compile(self.rule, "rule", "eval")
        except Exception as e:
            raise RuleProcessingError(
                f"Could not process rule {self.rule_string!r}: {e}"
            ) from e
        for name in self.rule.co_names:
            if name not in ns:
                raise RuleProcessingError(
                    f"Could not process rule {self.rule_string!r}: Unknown name {name!r}"
                )

    @staticmethod
    def _kinematic_labels(process_type):
        """Return the kinematic labels for ``process_type``.

        Every process whose name starts with ``DIS`` shares the ``DIS`` labels.
        """
        if process_type[:3] == "DIS":
            return CommonData.kinLabel["DIS"]
        return CommonData.kinLabel[process_type]

    @property
    def _properties(self):
        """Attributes of the Rule class that are defining. Two
        Rules with identical ``_properties`` are considered equal.
        """
        return (self.rule_string, self.dataset, self.process_type, self.theory_params['ID'])

    def __eq__(self, other):
        # Let Python fall back (to identity) for unrelated types instead of
        # raising AttributeError on ``other._properties``.
        if not isinstance(other, Rule):
            return NotImplemented
        return self._properties == other._properties

    def __hash__(self):
        return hash(self._properties)

    def __call__(self, dataset, idat):
        """Apply the rule to point ``idat`` of ``dataset``.

        Returns
        -------
        None if the rule does not apply to this point (different dataset and
        process type, or theory conditions not met). Otherwise, the value of
        the rule expression, which is truthy when the point passes.

        Raises
        ------
        FatalRuleError
            If evaluating the local variables or the rule expression fails.
        """
        central_value = dataset.GetData(idat)
        # We return None if the rule doesn't apply. This
        # is different to the case where the rule does apply,
        # but the point was cut out by the rule.
        if (
            dataset.GetSetName() != self.dataset
            and dataset.GetProc(idat) != self.process_type
            and self.process_type != "DIS_ALL"
        ):
            return None

        # Handle the generalised DIS cut
        if self.process_type == "DIS_ALL" and dataset.GetProc(idat)[:3] != "DIS":
            return None

        # Check theory conditions before any per-point evaluation, so local
        # variables are not computed for a rule that is going to be skipped.
        for k, v in self.theory_params.items():
            if k == "PTO" and hasattr(self, "PTO"):
                if v not in self.PTO:
                    return None
            elif hasattr(self, k) and getattr(self, k) != v:
                return None

        try:
            ns = self._make_point_namespace(dataset, idat, central_value)
            # A fresh, single namespace: ``eval`` inserts ``__builtins__`` into
            # the globals it is given, so passing the class-level
            # ``numpy_functions`` would mutate shared state. A single dict also
            # lets nested scopes in the rule (e.g. comprehensions) see every name.
            # Will return a truthy value if the datapoint passes the filter.
            return eval(self.rule, self._evaluation_context(idat, central_value, ns))
        except Exception as e:  # pragma: no cover
            raise FatalRuleError(
                f"Error when applying rule {self.rule_string!r}: {e}"
            ) from e

    def __repr__(self):  # pragma: no cover
        return self.rule_string

    def _evaluation_context(self, idat, central_value, ns=None) -> dict:
        """Return a new dict with every name an expression may use.

        Later entries take precedence: numpy functions, then ``idat`` and
        ``central_value``, then defaults, then ``ns`` (kinematics and local
        variables).
        """
        return {
            **self.numpy_functions,
            "idat": idat,
            "central_value": central_value,
            **self.defaults,
            **(ns or {}),
        }

    def _make_kinematics_dict(self, dataset, idat) -> dict:
        """Fill in a dictionary with the kinematics for each point"""
        kinematics = [dataset.GetKinematics(idat, j) for j in range(3)]
        return dict(zip(self.variables, kinematics))

    def _make_point_namespace(self, dataset, idat, central_value=None) -> dict:
        """Return a dictionary with the kinematics and local variables of ``idat``.

        Local variables are evaluated in definition order. Each one sees the
        names accepted by the validation in ``__init__``: numpy functions,
        ``idat``, ``central_value``, defaults, kinematics and earlier local
        variables. ``central_value`` is read from ``dataset`` when not given.
        """
        if central_value is None:
            central_value = dataset.GetData(idat)
        ns = self._make_kinematics_dict(dataset, idat)
        base = self._evaluation_context(idat, central_value)
        for key, code in self._local_variables_code.items():
            ns[key] = eval(code, {**base, **ns})
        return ns
def get_cuts_for_dataset(commondata, rules) -> list:
    """Return the indices of the points in ``commondata`` that pass all ``rules``.

    A point is cut as soon as one applicable rule evaluates to a falsy value.
    Rules returning ``None`` do not apply to that point and are ignored. The
    rules are normally built from ./cuts/filters.yaml.

    Parameters
    ----------
    commondata: NNPDF CommonData spec
    rules: List[Rule]
        A list of Rule objects specifying the filters.

    Returns
    -------
    mask: list
        Indices of all data points that passed, in increasing order.

    Example
    -------
    >>> from validphys.filters import (get_cuts_for_dataset, Rule,
    ...     default_filter_settings_input, default_filter_rules_input)
    >>> from validphys.loader import Loader
    >>> l = Loader()
    >>> cd = l.check_commondata("NMC")
    >>> theory = l.check_theoryID(53)
    >>> filter_defaults = default_filter_settings_input()
    >>> params = theory.get_description()
    >>> rule_list = [Rule(initial_data=i, defaults=filter_defaults, theory_parameters=params)
    ...     for i in default_filter_rules_input()]
    >>> get_cuts_for_dataset(cd, rules=rule_list)
    """
    loaded = commondata.load()

    def _passes_all(point):
        # Stop at the first applicable rule that rejects the point.
        for rule in rules:
            verdict = rule(loaded, point)
            if verdict is not None and not verdict:
                return False
        return True

    return [point for point in range(loaded.GetNData()) if _passes_all(point)]