Skip to content

Commit 71babd0

Browse files
committed
removed resample_negative_pseudodata from make_replica, put into group_positivity_mask
1 parent e03701b commit 71babd0

File tree

1 file changed

+19
-16
lines changed

1 file changed

+19
-16
lines changed

validphys2/src/validphys/pseudodata.py

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,6 @@ def make_replica(
129129
sep_mult=False,
130130
genrep=True,
131131
max_tries=int(1e6),
132-
resample_negative_pseudodata=False,
133132
):
134133
"""Function that takes in a central value array and a covariance matrix
135134
and returns a pseudodata replica accounting for
@@ -154,6 +153,9 @@ def make_replica(
154153
group_multiplicative_errors: dict
155154
Dictionary containing the multiplicative uncertainties contribution to the pseudodata replica.
156155
156+
group_positivity_mask: np.array
157+
Boolean array of shape (N_dat,) indicating which data points should be positive.
158+
157159
sep_mult: bool
158160
Specifies whether computing the shifts with the full covmat
159161
or whether multiplicative errors should be separated
@@ -166,9 +168,6 @@ def make_replica(
166168
If after max_tries (default=1e6) no physical configuration is found,
167169
it will raise a :py:class:`ReplicaGenerationError`
168170
169-
resample_negative_pseudodata: bool
170-
When True, replicas that produce negative predictions will be resampled for ``max_tries``
171-
until all points are positive (default: False)
172171
Returns
173172
-------
174173
pseudodata: np.array
@@ -202,7 +201,7 @@ def make_replica(
202201
full_mask = (
203202
group_positivity_mask
204203
if group_positivity_mask is not None
205-
else np.ones_like(central_values_array, dtype=bool)
204+
else np.zeros_like(central_values_array, dtype=bool)
206205
)
207206
# The inner while True loop is for ensuring a positive definite
208207
# pseudodata replica
@@ -236,7 +235,7 @@ def make_replica(
236235
# Shifting pseudodata
237236
shifted_pseudodata = (central_values_array + shifts) * mult_part
238237
# positivity control
239-
if np.all(shifted_pseudodata[full_mask] >= 0) or not resample_negative_pseudodata:
238+
if np.all(shifted_pseudodata[full_mask] >= 0):
240239
return shifted_pseudodata
241240

242241
# Find which dataset index corresponds to the negative points, and print it out for debugging purposes
@@ -291,19 +290,23 @@ def group_multiplicative_errors(groups_dataset_inputs_loaded_cd_with_cuts, sep_m
291290
return multiplicative_errors
292291

293292

294-
def group_positivity_mask(groups_dataset_inputs_loaded_cd_with_cuts):
293+
def group_positivity_mask(
294+
groups_dataset_inputs_loaded_cd_with_cuts, resample_negative_pseudodata=False
295+
):
295296
"""Function that takes in a list of :py:class:`nnpdf_data.coredata.CommonData`
296297
and returns a boolean mask indicating which data points should be positive.
297298
"""
298-
299-
check_positive_masks = []
300-
for cd in groups_dataset_inputs_loaded_cd_with_cuts:
301-
if "ASY" in cd.commondataproc or cd.commondataproc.endswith("_POL"):
302-
check_positive_masks.append(np.zeros_like(cd.central_values.to_numpy(), dtype=bool))
303-
else:
304-
check_positive_masks.append(np.ones_like(cd.central_values.to_numpy(), dtype=bool))
305-
full_mask = np.concatenate(check_positive_masks, axis=0)
306-
return full_mask
299+
if not resample_negative_pseudodata:
300+
return None
301+
else:
302+
check_positive_masks = []
303+
for cd in groups_dataset_inputs_loaded_cd_with_cuts:
304+
if "ASY" in cd.commondataproc or cd.commondataproc.endswith("_POL"):
305+
check_positive_masks.append(np.zeros_like(cd.central_values.to_numpy(), dtype=bool))
306+
else:
307+
check_positive_masks.append(np.ones_like(cd.central_values.to_numpy(), dtype=bool))
308+
full_mask = np.concatenate(check_positive_masks, axis=0)
309+
return full_mask
307310

308311

309312
def indexed_make_replica(groups_index, make_replica):

0 commit comments

Comments
 (0)