Skip to content

Commit 581178f

Browse files
committed
Added mask func and docstrings
1 parent 60e5f75 commit 581178f

File tree

2 files changed

+51
-29
lines changed

2 files changed

+51
-29
lines changed

validphys2/src/validphys/n3fit_data.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,19 @@ def replica_luxseed(replica, luxseed):
111111

112112

113113
def group_replica_mcseed(replica_mcseed, groups_dataset_inputs_loaded_cd_with_cuts):
114+
"""Generates the ``mcseed`` for a group of datasets. This is done by hashing the names
115+
of the datasets in the group and adding it to the ``replica_mcseed``.
116+
Parameters
117+
---------
118+
groups_dataset_inputs_loaded_cd_with_cuts: list[:py:class:`nnpdf_data.coredata.CommonData`]
119+
List of CommonData objects which stores information about systematic errors,
120+
their treatment and description, for each dataset.
121+
replica_mcseed: int
122+
"""
114123
names_for_salt = []
124+
# Try to use the new dataset name, but make older runs reproducible by keeping the old names.
125+
# WARNING: don't rely on this behaviour, this might be removed in future releases
126+
115127
for loaded_cd in groups_dataset_inputs_loaded_cd_with_cuts:
116128
if loaded_cd.legacy_names is None:
117129
names_for_salt.append(loaded_cd.setname)

validphys2/src/validphys/pseudodata.py

Lines changed: 39 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -121,18 +121,18 @@ def read_replica_pseudodata(fit, context_index, replica):
121121

122122

123123
def make_replica(
124-
groups_dataset_inputs_loaded_cd_with_cuts,
125124
central_values_array,
126125
group_replica_mcseed,
127126
dataset_inputs_sampling_covmat,
128-
group_multiplicative_errors,
127+
group_multiplicative_errors=None,
128+
group_positivity_mask=None,
129129
sep_mult=False,
130130
genrep=True,
131131
max_tries=int(1e6),
132132
resample_negative_pseudodata=False,
133133
):
134-
"""Function that takes in a list of :py:class:`nnpdf_data.coredata.CommonData`
135-
objects and returns a pseudodata replica accounting for
134+
"""Function that takes in a central value array and a covariance matrix
135+
and returns a pseudodata replica accounting for
136136
possible correlations between systematic uncertainties.
137137
138138
The function loops until positive definite pseudodata is generated for any
@@ -141,17 +141,19 @@ def make_replica(
141141
142142
Parameters
143143
---------
144-
groups_dataset_inputs_loaded_cd_with_cuts: list[:py:class:`nnpdf_data.coredata.CommonData`]
145-
List of CommonData objects which stores information about systematic errors,
146-
their treatment and description, for each dataset.
144+
central_values_array: np.array
145+
Numpy array of length N_dat (where N_dat is the combined number of data points after cuts)
146+
containing the central values of the data.
147147
148-
replica_mcseed: int, None
149-
Seed used to initialise the numpy random number generator. If ``None`` then a random seed is
150-
allocated using the default numpy behaviour.
148+
group_replica_mcseed: int
149+
Seed used to initialise the numpy random number generator.
151150
152151
dataset_inputs_sampling_covmat: np.array
153152
Full covmat to be used. It can be either only experimental or also theoretical.
154153
154+
group_multiplicative_errors: dict, None
155+
Dictionary containing the contribution of the multiplicative uncertainties to the pseudodata replica.
156+
155157
sep_mult: bool
156158
Specifies whether computing the shifts with the full covmat
157159
or whether multiplicative errors should be separated
@@ -197,11 +199,11 @@ def make_replica(
197199
covmat = dataset_inputs_sampling_covmat
198200
covmat_sqrt = sqrt_covmat(covmat)
199201

200-
all_pseudodata = central_values_array
201-
if group_multiplicative_errors is not None:
202-
full_mask = group_multiplicative_errors["full_mask"]
203-
else:
204-
full_mask = np.ones_like(central_values_array, dtype=bool)
202+
full_mask = (
203+
group_positivity_mask
204+
if group_positivity_mask is not None
205+
else np.ones_like(central_values_array, dtype=bool)
206+
)
205207
# The loop below makes at most max_tries attempts at generating a replica whose
206208
# pseudodata is non-negative wherever full_mask is set
207209
for _ in range(max_tries):
@@ -232,18 +234,16 @@ def make_replica(
232234
).prod(axis=1)
233235
mult_part = np.concatenate(mult_shifts, axis=0) * special_mult
234236
# Shifting pseudodata
235-
shifted_pseudodata = (all_pseudodata + shifts) * mult_part
237+
shifted_pseudodata = (central_values_array + shifts) * mult_part
236238
# positivity control
237239
if np.all(shifted_pseudodata[full_mask] >= 0) or not resample_negative_pseudodata:
238240
return shifted_pseudodata
239241

240-
dfail = " ".join(i.setname for i in groups_dataset_inputs_loaded_cd_with_cuts)
241-
log.error(f"Error generating replicas for the group: {dfail}")
242242
raise ReplicaGenerationError(f"No valid replica found after {max_tries} attempts")
243243

244244

245245
def central_values_array(groups_dataset_inputs_loaded_cd_with_cuts):
246-
"""Function that takes in a list of :py:class:`nnpdf_data.coredata.CommonData
246+
"""Function that takes in a list of :py:class:`nnpdf_data.coredata.CommonData`
247247
and returns the central values concatenated in a single array.
248248
"""
249249
central_values = []
@@ -253,11 +253,12 @@ def central_values_array(groups_dataset_inputs_loaded_cd_with_cuts):
253253

254254

255255
def group_multiplicative_errors(groups_dataset_inputs_loaded_cd_with_cuts, sep_mult):
256-
"""Function that takes in a list of :py:class:`nnpdf_data.coredata.CommonData
256+
"""Function that takes in a list of :py:class:`nnpdf_data.coredata.CommonData`
257257
and returns the multiplicative uncertainties contribution to the pseudodata replica.
258258
"""
259+
if not sep_mult:
260+
return None
259261

260-
check_positive_masks = []
261262
nonspecial_mult = []
262263
special_mult = []
263264
special_mult_errors = []
@@ -270,25 +271,34 @@ def group_multiplicative_errors(groups_dataset_inputs_loaded_cd_with_cuts, sep_m
270271
special_mult.append(
271272
mult_errors.loc[:, ~mult_errors.columns.isin(INTRA_DATASET_SYS_NAME)]
272273
)
273-
if "ASY" in cd.commondataproc or cd.commondataproc.endswith("_POL"):
274-
check_positive_masks.append(np.zeros_like(cd.central_values.to_numpy(), dtype=bool))
275-
else:
276-
check_positive_masks.append(np.ones_like(cd.central_values.to_numpy(), dtype=bool))
277-
# concatenating special multiplicative errors, pseudodatas and positive mask
274+
275+
# concatenating special multiplicative errors
278276
if sep_mult:
279277
special_mult_errors = pd.concat(special_mult, axis=0, sort=True).fillna(0).to_numpy()
280278

281-
full_mask = np.concatenate(check_positive_masks, axis=0)
282-
283279
multiplicative_errors = {
284280
"nonspecial_mult": nonspecial_mult,
285281
"special_mult": special_mult_errors,
286-
"full_mask": full_mask,
287282
}
288283

289284
return multiplicative_errors
290285

291286

287+
def group_positivity_mask(groups_dataset_inputs_loaded_cd_with_cuts):
288+
"""Function that takes in a list of :py:class:`nnpdf_data.coredata.CommonData`
289+
and returns a boolean mask indicating which data points should be positive.
290+
"""
291+
292+
check_positive_masks = []
293+
for cd in groups_dataset_inputs_loaded_cd_with_cuts:
294+
if "ASY" in cd.commondataproc or cd.commondataproc.endswith("_POL"):
295+
check_positive_masks.append(np.zeros_like(cd.central_values.to_numpy(), dtype=bool))
296+
else:
297+
check_positive_masks.append(np.ones_like(cd.central_values.to_numpy(), dtype=bool))
298+
full_mask = np.concatenate(check_positive_masks, axis=0)
299+
return full_mask
300+
301+
292302
def indexed_make_replica(groups_index, make_replica):
293303
"""Index the make_replica pseudodata appropriately"""
294304

0 commit comments

Comments
 (0)