@@ -122,8 +122,10 @@ def read_replica_pseudodata(fit, context_index, replica):
122122
123123def make_replica (
124124 groups_dataset_inputs_loaded_cd_with_cuts ,
125+ central_values_array ,
125126 group_replica_mcseed ,
126127 dataset_inputs_sampling_covmat ,
128+ replica_multiplicative_errors ,
127129 sep_mult = False ,
128130 genrep = True ,
129131 max_tries = int (1e6 ),
@@ -187,73 +189,44 @@ def make_replica(
187189 0.34206012, 0.31866286, 0.2790856 , 0.33257621, 0.33680007,
188190 """
189191 if not genrep :
190- return np .concatenate (
191- [cd .central_values for cd in groups_dataset_inputs_loaded_cd_with_cuts ]
192- )
192+ return central_values_array
193193
194- # Set random seed to replica_mcseed - Would also like to change this seed for each group via 'groupname' (this can't yet be accessed here)
194+ # Set random seed
195195 rng = np .random .default_rng (seed = group_replica_mcseed )
196196 # construct covmat
197197 covmat = dataset_inputs_sampling_covmat
198198 covmat_sqrt = sqrt_covmat (covmat )
199- # Loading the data
200- pseudodatas = []
201- check_positive_masks = []
202- nonspecial_mult = []
203- special_mult = []
204- for cd in groups_dataset_inputs_loaded_cd_with_cuts :
205- # copy here to avoid mutating the central values.
206- is_commondata = hasattr (cd , "central_values" ) and cd .central_values is not None
207- if is_commondata :
208- pseudodata = cd .central_values .to_numpy ()
209- else :
210- pseudodata = np .asarray (cd )
211-
212- pseudodatas .append (pseudodata )
213- # Separation of multiplicative errors. If sep_mult is True also the exp_covmat is produced
214- # without multiplicative errors
215- if is_commondata == True :
216- if sep_mult :
217- mult_errors = cd .multiplicative_errors
218- mult_uncorr_errors = mult_errors .loc [:, mult_errors .columns == "UNCORR" ].to_numpy ()
219- mult_corr_errors = mult_errors .loc [:, mult_errors .columns == "CORR" ].to_numpy ()
220- nonspecial_mult .append ((mult_uncorr_errors , mult_corr_errors ))
221- special_mult .append (
222- mult_errors .loc [:, ~ mult_errors .columns .isin (INTRA_DATASET_SYS_NAME )]
223- )
224- if "ASY" in cd .commondataproc or cd .commondataproc .endswith ("_POL" ):
225- check_positive_masks .append (np .zeros_like (pseudodata , dtype = bool ))
226- else :
227- check_positive_masks .append (np .ones_like (pseudodata , dtype = bool ))
228- # If the input is not a commondata instance, then we assume there are no multiplicative errors and that all points must be positive
229- else :
230- check_positive_masks .append (np .ones_like (pseudodata , dtype = bool ))
231- # concatenating special multiplicative errors, pseudodatas and positive mask
232- if sep_mult :
233- special_mult_errors = pd .concat (special_mult , axis = 0 , sort = True ).fillna (0 ).to_numpy ()
234- all_pseudodata = np .concatenate (pseudodatas , axis = 0 )
235- full_mask = np .concatenate (check_positive_masks , axis = 0 )
199+
200+ all_pseudodata = central_values_array
201+ if replica_multiplicative_errors is not None :
202+ full_mask = replica_multiplicative_errors ["full_mask" ]
203+ else :
204+ full_mask = np .ones_like (central_values_array , dtype = bool )
236205 # The loop below retries (at most max_tries times) until a valid,
237206 # positive where required, pseudodata replica is found
238207 for _ in range (max_tries ):
239208 mult_shifts = []
240209 # Prepare the per-dataset multiplicative shifts
241- for mult_uncorr_errors , mult_corr_errors in nonspecial_mult :
242- # convert to from percent to fraction
243- mult_shift = (
244- 1 + mult_uncorr_errors * rng .normal (size = mult_uncorr_errors .shape ) / 100
245- ).prod (axis = 1 )
210+ if replica_multiplicative_errors is not None :
211+ for mult_uncorr_errors , mult_corr_errors in replica_multiplicative_errors [
212+ "nonspecial_mult"
213+ ]:
214+ # convert from percent to fraction
215+ mult_shift = (
216+ 1 + mult_uncorr_errors * rng .normal (size = mult_uncorr_errors .shape ) / 100
217+ ).prod (axis = 1 )
246218
247- mult_shift *= (
248- 1 + mult_corr_errors * rng .normal (size = (1 , mult_corr_errors .shape [1 ])) / 100
249- ).prod (axis = 1 )
219+ mult_shift *= (
220+ 1 + mult_corr_errors * rng .normal (size = (1 , mult_corr_errors .shape [1 ])) / 100
221+ ).prod (axis = 1 )
250222
251- mult_shifts .append (mult_shift )
223+ mult_shifts .append (mult_shift )
252224
253225 # If sep_mult is true then the multiplicative shifts were not included in the covmat
254226 shifts = covmat_sqrt @ rng .normal (size = covmat .shape [1 ])
255227 mult_part = 1.0
256228 if sep_mult :
229+ special_mult_errors = replica_multiplicative_errors ["special_mult" ]
257230 special_mult = (
258231 1 + special_mult_errors * rng .normal (size = (1 , special_mult_errors .shape [1 ])) / 100
259232 ).prod (axis = 1 )
@@ -269,6 +242,53 @@ def make_replica(
269242 raise ReplicaGenerationError (f"No valid replica found after { max_tries } attempts" )
270243
271244
def central_values_array(groups_dataset_inputs_loaded_cd_with_cuts):
    """Concatenate the central values of every loaded dataset into one array.

    Takes an iterable of :py:class:`nnpdf_data.coredata.CommonData` objects
    (each must expose ``central_values`` with a ``to_numpy()`` method) and
    returns a single numpy array with the per-dataset central values joined
    along the first axis, in the order the datasets were given.
    """
    per_dataset_values = [
        cd.central_values.to_numpy() for cd in groups_dataset_inputs_loaded_cd_with_cuts
    ]
    return np.concatenate(per_dataset_values, axis=0)
253+
254+
def replica_multiplicative_errors(groups_dataset_inputs_loaded_cd_with_cuts, sep_mult):
    """Collect the multiplicative-uncertainty inputs used to build a pseudodata replica.

    Takes an iterable of :py:class:`nnpdf_data.coredata.CommonData` objects and
    returns a dictionary with three entries:

    ``"nonspecial_mult"``
        One ``(uncorrelated, correlated)`` tuple of numpy arrays per dataset,
        holding the columns of ``multiplicative_errors`` named ``"UNCORR"`` and
        ``"CORR"`` respectively. Empty list when ``sep_mult`` is false.
    ``"special_mult"``
        Numpy array with the remaining multiplicative columns (those whose name
        is not in ``INTRA_DATASET_SYS_NAME``) of all datasets stacked row-wise;
        columns absent from a dataset are filled with 0. Empty list when
        ``sep_mult`` is false.
    ``"full_mask"``
        Boolean array with one entry per data point. It is ``False`` for every
        point of a dataset whose ``commondataproc`` contains ``"ASY"`` or ends
        with ``"_POL"``, and ``True`` otherwise.
        NOTE(review): presumably flags the points that must stay positive in
        the generated replica -- confirm against the caller.
    """
    positivity_masks = []
    per_dataset_mult = []
    cross_dataset_frames = []
    cross_dataset_mult = []

    for commondata in groups_dataset_inputs_loaded_cd_with_cuts:
        if sep_mult:
            errors = commondata.multiplicative_errors
            uncorrelated = errors.loc[:, errors.columns == "UNCORR"].to_numpy()
            correlated = errors.loc[:, errors.columns == "CORR"].to_numpy()
            per_dataset_mult.append((uncorrelated, correlated))
            # Everything that is not an intra-dataset systematic is kept as a
            # labelled frame so that it can be aligned by column name below.
            is_intra_dataset = errors.columns.isin(INTRA_DATASET_SYS_NAME)
            cross_dataset_frames.append(errors.loc[:, ~is_intra_dataset])

        process = commondata.commondataproc
        exempt_from_check = "ASY" in process or process.endswith("_POL")
        central = commondata.central_values.to_numpy()
        positivity_masks.append(np.full_like(central, not exempt_from_check, dtype=bool))

    if sep_mult:
        # Stack all datasets, aligning on the (sorted) systematic names.
        stacked = pd.concat(cross_dataset_frames, axis=0, sort=True)
        cross_dataset_mult = stacked.fillna(0).to_numpy()

    return {
        "nonspecial_mult": per_dataset_mult,
        "special_mult": cross_dataset_mult,
        "full_mask": np.concatenate(positivity_masks, axis=0),
    }
290+
291+
272292def indexed_make_replica (groups_index , make_replica ):
273293 """Index the make_replica pseudodata appropriately"""
274294
0 commit comments