Skip to content

Commit 3726aee

Browse files
committed
Make_replica take central value array and mult_unc dict
1 parent ec2feae commit 3726aee

File tree

1 file changed

+70
-50
lines changed

1 file changed

+70
-50
lines changed

validphys2/src/validphys/pseudodata.py

Lines changed: 70 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -122,8 +122,10 @@ def read_replica_pseudodata(fit, context_index, replica):
122122

123123
def make_replica(
124124
groups_dataset_inputs_loaded_cd_with_cuts,
125+
central_values_array,
125126
group_replica_mcseed,
126127
dataset_inputs_sampling_covmat,
128+
replica_multiplicative_errors,
127129
sep_mult=False,
128130
genrep=True,
129131
max_tries=int(1e6),
@@ -187,73 +189,44 @@ def make_replica(
187189
0.34206012, 0.31866286, 0.2790856 , 0.33257621, 0.33680007,
188190
"""
189191
if not genrep:
190-
return np.concatenate(
191-
[cd.central_values for cd in groups_dataset_inputs_loaded_cd_with_cuts]
192-
)
192+
return central_values_array
193193

194-
# Set random seed to replica_mcseed - Would also like to change this seed for each group via 'groupname' (this can't yet be accessed here)
194+
# Set random seed
195195
rng = np.random.default_rng(seed=group_replica_mcseed)
196196
# construct covmat
197197
covmat = dataset_inputs_sampling_covmat
198198
covmat_sqrt = sqrt_covmat(covmat)
199-
# Loading the data
200-
pseudodatas = []
201-
check_positive_masks = []
202-
nonspecial_mult = []
203-
special_mult = []
204-
for cd in groups_dataset_inputs_loaded_cd_with_cuts:
205-
# copy here to avoid mutating the central values.
206-
is_commondata = hasattr(cd, "central_values") and cd.central_values is not None
207-
if is_commondata:
208-
pseudodata = cd.central_values.to_numpy()
209-
else:
210-
pseudodata = np.asarray(cd)
211-
212-
pseudodatas.append(pseudodata)
213-
# Separation of multiplicative errors. If sep_mult is True also the exp_covmat is produced
214-
# without multiplicative errors
215-
if is_commondata == True:
216-
if sep_mult:
217-
mult_errors = cd.multiplicative_errors
218-
mult_uncorr_errors = mult_errors.loc[:, mult_errors.columns == "UNCORR"].to_numpy()
219-
mult_corr_errors = mult_errors.loc[:, mult_errors.columns == "CORR"].to_numpy()
220-
nonspecial_mult.append((mult_uncorr_errors, mult_corr_errors))
221-
special_mult.append(
222-
mult_errors.loc[:, ~mult_errors.columns.isin(INTRA_DATASET_SYS_NAME)]
223-
)
224-
if "ASY" in cd.commondataproc or cd.commondataproc.endswith("_POL"):
225-
check_positive_masks.append(np.zeros_like(pseudodata, dtype=bool))
226-
else:
227-
check_positive_masks.append(np.ones_like(pseudodata, dtype=bool))
228-
# If the input is not a commondata instance, then we assume there are no multiplicative errors and that all points must be positive
229-
else:
230-
check_positive_masks.append(np.ones_like(pseudodata, dtype=bool))
231-
# concatenating special multiplicative errors, pseudodatas and positive mask
232-
if sep_mult:
233-
special_mult_errors = pd.concat(special_mult, axis=0, sort=True).fillna(0).to_numpy()
234-
all_pseudodata = np.concatenate(pseudodatas, axis=0)
235-
full_mask = np.concatenate(check_positive_masks, axis=0)
199+
200+
all_pseudodata = central_values_array
201+
if replica_multiplicative_errors is not None:
202+
full_mask = replica_multiplicative_errors["full_mask"]
203+
else:
204+
full_mask = np.ones_like(central_values_array, dtype=bool)
236205
# The inner while True loop is for ensuring a positive definite
237206
# pseudodata replica
238207
for _ in range(max_tries):
239208
mult_shifts = []
240209
# Prepare the per-dataset multiplicative shifts
241-
for mult_uncorr_errors, mult_corr_errors in nonspecial_mult:
242-
# convert to from percent to fraction
243-
mult_shift = (
244-
1 + mult_uncorr_errors * rng.normal(size=mult_uncorr_errors.shape) / 100
245-
).prod(axis=1)
210+
if replica_multiplicative_errors is not None:
211+
for mult_uncorr_errors, mult_corr_errors in replica_multiplicative_errors[
212+
"nonspecial_mult"
213+
]:
214+
# convert to from percent to fraction
215+
mult_shift = (
216+
1 + mult_uncorr_errors * rng.normal(size=mult_uncorr_errors.shape) / 100
217+
).prod(axis=1)
246218

247-
mult_shift *= (
248-
1 + mult_corr_errors * rng.normal(size=(1, mult_corr_errors.shape[1])) / 100
249-
).prod(axis=1)
219+
mult_shift *= (
220+
1 + mult_corr_errors * rng.normal(size=(1, mult_corr_errors.shape[1])) / 100
221+
).prod(axis=1)
250222

251-
mult_shifts.append(mult_shift)
223+
mult_shifts.append(mult_shift)
252224

253225
# If sep_mult is true then the multiplicative shifts were not included in the covmat
254226
shifts = covmat_sqrt @ rng.normal(size=covmat.shape[1])
255227
mult_part = 1.0
256228
if sep_mult:
229+
special_mult_errors = replica_multiplicative_errors["special_mult"]
257230
special_mult = (
258231
1 + special_mult_errors * rng.normal(size=(1, special_mult_errors.shape[1])) / 100
259232
).prod(axis=1)
@@ -269,6 +242,53 @@ def make_replica(
269242
raise ReplicaGenerationError(f"No valid replica found after {max_tries} attempts")
270243

271244

245+
def central_values_array(groups_dataset_inputs_loaded_cd_with_cuts):
246+
"""Function that takes in a list of :py:class:`nnpdf_data.coredata.CommonData
247+
and returns the central values concatenated in a single array.
248+
"""
249+
central_values = []
250+
for cd in groups_dataset_inputs_loaded_cd_with_cuts:
251+
central_values.append(cd.central_values.to_numpy())
252+
return np.concatenate(central_values, axis=0)
253+
254+
255+
def replica_multiplicative_errors(groups_dataset_inputs_loaded_cd_with_cuts, sep_mult):
256+
"""Function that takes in a list of :py:class:`nnpdf_data.coredata.CommonData
257+
and returns the multiplicative uncertainties contribution to the pseudodata replica.
258+
"""
259+
260+
check_positive_masks = []
261+
nonspecial_mult = []
262+
special_mult = []
263+
special_mult_errors = []
264+
for cd in groups_dataset_inputs_loaded_cd_with_cuts:
265+
if sep_mult:
266+
mult_errors = cd.multiplicative_errors
267+
mult_uncorr_errors = mult_errors.loc[:, mult_errors.columns == "UNCORR"].to_numpy()
268+
mult_corr_errors = mult_errors.loc[:, mult_errors.columns == "CORR"].to_numpy()
269+
nonspecial_mult.append((mult_uncorr_errors, mult_corr_errors))
270+
special_mult.append(
271+
mult_errors.loc[:, ~mult_errors.columns.isin(INTRA_DATASET_SYS_NAME)]
272+
)
273+
if "ASY" in cd.commondataproc or cd.commondataproc.endswith("_POL"):
274+
check_positive_masks.append(np.zeros_like(cd.central_values.to_numpy(), dtype=bool))
275+
else:
276+
check_positive_masks.append(np.ones_like(cd.central_values.to_numpy(), dtype=bool))
277+
# concatenating special multiplicative errors, pseudodatas and positive mask
278+
if sep_mult:
279+
special_mult_errors = pd.concat(special_mult, axis=0, sort=True).fillna(0).to_numpy()
280+
281+
full_mask = np.concatenate(check_positive_masks, axis=0)
282+
283+
multiplicative_errors = {
284+
"nonspecial_mult": nonspecial_mult,
285+
"special_mult": special_mult_errors,
286+
"full_mask": full_mask,
287+
}
288+
289+
return multiplicative_errors
290+
291+
272292
def indexed_make_replica(groups_index, make_replica):
273293
"""Index the make_replica pseudodata appropriately"""
274294

0 commit comments

Comments
 (0)