Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions wmin/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,3 @@ def write_pod_basis(
" 2. Shift all others down by one index\n"
" 3. Make replica_1 the new central member of the post-fit basis"
)
log.warning(
"Reminder: decrement `NumMembers` by 1 in the LHAPDF .info file to reflect the removed member."
)
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
pod_basis_0002.dat -> pod_basis_0001.dat
...
If a file with index 0000 exists, it will be deleted before renaming to avoid collisions.
Also updates the NumMembers field in any .info files.
"""

import os
Expand Down Expand Up @@ -86,6 +87,20 @@ def main():
os.remove(dst)
os.rename(tmp, dst)

# Update NumMembers in .info files
info_file = [f for f in all_files if f.endswith(".info")][0]
num_members = len(to_shift) # number of shifted files = number of final members

info_path = os.path.join(args.directory, info_file)
print(f"Updating NumMembers to {num_members} in: {info_path}")
if not args.dry_run:
with open(info_path, "r") as f:
content = f.read()
# Replace REPLACE_NREP or any existing NumMembers value
content = re.sub(r"NumMembers:.*", f"NumMembers: {num_members}", content)
with open(info_path, "w") as f:
f.write(content)

print("Done.")


Expand Down
28 changes: 25 additions & 3 deletions wmin/tests/test_likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@
import time

import jax
import jax.numpy as jnp
import jax.scipy.linalg as jla
import pytest
from colibri.api import API as colibriAPI
from colibri.loss_functions import chi2
from colibri.bayes_prior import bayesian_prior
from colibri.core import BayesianPrior
from colibri.tests.conftest import (
T0_PDFSET,
TEST_DATASETS,
Expand Down Expand Up @@ -45,6 +47,26 @@
RNG_KEY = 0


def mock_prior_transform(x):
return x


def mock_log_prob(x):
return jnp.array(0.0)


def mock_sample(rng_key, n_samples):
n_params = len(MOCK_PDF_MODEL.param_names)
return jax.random.uniform(rng_key, shape=(n_samples, n_params))


bayesian_prior = BayesianPrior(
prior_transform=lambda x: x,
log_prob=lambda x: -jnp.sum(x**2, axis=-1),
sample=lambda rng, n: jnp.zeros((n, MOCK_PDF_MODEL.n_parameters)),
)


def prior_samples(prior, wmin_model_settings):
# Sample params from the prior
prior_samples = []
Expand Down Expand Up @@ -86,7 +108,7 @@ def test_likelihood_dis_wmin(wmin_model_settings):
**{**wmin_model_settings, "output_path": None, "dump_model": False}
)
# get bayesian prior
prior = bayesian_prior(TEST_PRIOR_SETTINGS_WMIN, MOCK_PDF_MODEL)
prior = bayesian_prior.prior_transform

pred_and_pdf = pdf_model.pred_and_pdf_func(FIT_XGRID, forward_map=forward_map)

Expand Down Expand Up @@ -145,7 +167,7 @@ def test_likelihood_had_wmin(wmin_model_settings):
)

# get bayesian prior
prior = bayesian_prior(TEST_PRIOR_SETTINGS_WMIN, MOCK_PDF_MODEL)
prior = bayesian_prior.prior_transform

pred_and_pdf = pdf_model.pred_and_pdf_func(FIT_XGRID, forward_map=forward_map)

Expand Down Expand Up @@ -204,7 +226,7 @@ def test_likelihood_global_wmin(wmin_model_settings):
)

# get bayesian prior
prior = bayesian_prior(TEST_PRIOR_SETTINGS_WMIN, MOCK_PDF_MODEL)
prior = bayesian_prior.prior_transform

pred_and_pdf = pdf_model.pred_and_pdf_func(FIT_XGRID, forward_map=forward_map)

Expand Down
Loading