@@ -129,7 +129,6 @@ def make_replica(
129129 sep_mult = False ,
130130 genrep = True ,
131131 max_tries = int (1e6 ),
132- resample_negative_pseudodata = False ,
133132):
134133 """Function that takes in a central value array and a covariance matrix
135134 and returns a pseudodata replica accounting for
@@ -154,6 +153,9 @@ def make_replica(
154153 group_multiplicative_errors: dict
155154 Dictionary containing the multiplicative uncertainties contribution to the pseudodata replica.
156155
156+ group_positivity_mask: np.array
157+ Boolean array of shape (N_dat,) indicating which data points should be positive.
158+
157159 sep_mult: bool
158160 Specifies whether computing the shifts with the full covmat
159161 or whether multiplicative errors should be separated
@@ -166,9 +168,6 @@ def make_replica(
166168 If after max_tries (default=1e6) no physical configuration is found,
167169 it will raise a :py:class:`ReplicaGenerationError`
168170
169- resample_negative_pseudodata: bool
170- When True, replicas that produce negative predictions will be resampled for ``max_tries``
171- until all points are positive (default: False)
172171 Returns
173172 -------
174173 pseudodata: np.array
@@ -202,7 +201,7 @@ def make_replica(
202201 full_mask = (
203202 group_positivity_mask
204203 if group_positivity_mask is not None
205- else np .ones_like (central_values_array , dtype = bool )
204+ else np .zeros_like (central_values_array , dtype = bool )
206205 )
207206 # The inner while True loop is for ensuring a positive definite
208207 # pseudodata replica
@@ -236,7 +235,7 @@ def make_replica(
236235 # Shifting pseudodata
237236 shifted_pseudodata = (central_values_array + shifts ) * mult_part
238237 # positivity control
239- if np .all (shifted_pseudodata [full_mask ] >= 0 ) or not resample_negative_pseudodata :
238+ if np .all (shifted_pseudodata [full_mask ] >= 0 ):
240239 return shifted_pseudodata
241240
242241 # Find which dataset index corresponds to the negative points, and print it out for debugging purposes
@@ -291,19 +290,23 @@ def group_multiplicative_errors(groups_dataset_inputs_loaded_cd_with_cuts, sep_m
291290 return multiplicative_errors
292291
293292
294- def group_positivity_mask (groups_dataset_inputs_loaded_cd_with_cuts ):
293+ def group_positivity_mask (
294+ groups_dataset_inputs_loaded_cd_with_cuts , resample_negative_pseudodata = False
295+ ):
295296 """Function that takes in a list of :py:class:`nnpdf_data.coredata.CommonData`
296297 and returns a boolean mask indicating which data points should be positive.
297298 """
298-
299- check_positive_masks = []
300- for cd in groups_dataset_inputs_loaded_cd_with_cuts :
301- if "ASY" in cd .commondataproc or cd .commondataproc .endswith ("_POL" ):
302- check_positive_masks .append (np .zeros_like (cd .central_values .to_numpy (), dtype = bool ))
303- else :
304- check_positive_masks .append (np .ones_like (cd .central_values .to_numpy (), dtype = bool ))
305- full_mask = np .concatenate (check_positive_masks , axis = 0 )
306- return full_mask
299+ if not resample_negative_pseudodata :
300+ return None
301+ else :
302+ check_positive_masks = []
303+ for cd in groups_dataset_inputs_loaded_cd_with_cuts :
304+ if "ASY" in cd .commondataproc or cd .commondataproc .endswith ("_POL" ):
305+ check_positive_masks .append (np .zeros_like (cd .central_values .to_numpy (), dtype = bool ))
306+ else :
307+ check_positive_masks .append (np .ones_like (cd .central_values .to_numpy (), dtype = bool ))
308+ full_mask = np .concatenate (check_positive_masks , axis = 0 )
309+ return full_mask
307310
308311
309312def indexed_make_replica (groups_index , make_replica ):
0 commit comments