-
Notifications
You must be signed in to change notification settings - Fork 2.2k
Description
Describe the issue:
Hi all,
I've just updated PyMC to 5.25.1 and I'm finding that attempting to sample an HSGP with the ExpQuad covariance function raises `SamplingError: Initial evaluation of model at starting point failed!`. This appears to relate to the logp value of the HSGP coefficients (`f_hsgp_coeffs`) at initialisation, which is -inf according to the traceback (see below). There doesn't appear to be an issue when I swap out ExpQuad for something like Matern32. I've recreated this issue on Google Colab with PyMC 5.25.1. I was recently working with PyMC 5.23 on a different computer, which ran ExpQuad HSGPs fine, although I can't test this now.
Reproducible code example:
import pymc as pm
import numpy as np
# fake some data
x = np.sort(np.random.uniform(-1, 1, 101))
y = 3*np.cos(x*0.9) - 1
y += np.random.normal(scale=0.05, size=101)
with pm.Model(coords={"basis_coeffs": np.arange(200), "obs_id": np.arange(y.size)}) as model:
ell = pm.Exponential("ell", scale=1) # dont @ me for these priors...
eta = pm.Exponential("eta", scale=1.0)
cov_func = eta**2 * pm.gp.cov.ExpQuad(input_dim=1, ls=ell) # fails with pymc sampler, nutpie, numpyro, blackjax (so possibly not to do with nuts_sampler kwarg in pm.sample...)
#cov_func = eta**2 * pm.gp.cov.Matern32(input_dim=1, ls=ell) # this works with all
m, c = 200, 1.5
gp = pm.gp.HSGP(m=[m], c=c, parametrization="centered", cov_func=cov_func)
f = gp.prior("f", X=x[:, None], hsgp_coeffs_dims="basis_coeffs", gp_dims="obs_id")
sigma = pm.Exponential("sigma", scale=1.0)
pm.Normal("y_obs", mu=f, sigma=sigma, observed=y, dims="obs_id")
idata = pm.sample()
Error message:
---------------------------------------------------------------------------
SamplingError Traceback (most recent call last)
Cell In[16], line 16
13 sigma = pm.Exponential("sigma", scale=1.0)
14 pm.Normal("y_obs", mu=f, sigma=sigma, observed=y, dims="obs_id")
---> 16 idata = pm.sample()
File ~\.local\share\mamba\envs\pymc_dev\Lib\site-packages\pymc\sampling\mcmc.py:825, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, blas_cores, model, compile_kwargs, **kwargs)
823 [kwargs.setdefault(k, v) for k, v in nuts_kwargs.items()]
824 with joined_blas_limiter():
--> 825 initial_points, step = init_nuts(
826 init=init,
827 chains=chains,
828 n_init=n_init,
829 model=model,
830 random_seed=random_seed_list,
831 progressbar=progress_bool,
832 jitter_max_retries=jitter_max_retries,
833 tune=tune,
834 initvals=initvals,
835 compile_kwargs=compile_kwargs,
836 **kwargs,
837 )
838 else:
839 # Get initial points
840 ipfns = make_initial_point_fns_per_chain(
841 model=model,
842 overrides=initvals,
843 jitter_rvs=set(),
844 chains=chains,
845 )
File ~\.local\share\mamba\envs\pymc_dev\Lib\site-packages\pymc\sampling\mcmc.py:1598, in init_nuts(init, chains, n_init, model, random_seed, progressbar, jitter_max_retries, tune, initvals, compile_kwargs, **kwargs)
1595 q, _ = DictToArrayBijection.map(ip)
1596 return logp_dlogp_func([q], extra_vars={})[0]
-> 1598 initial_points = _init_jitter(
1599 model,
1600 initvals,
1601 seeds=random_seed_list,
1602 jitter="jitter" in init,
1603 jitter_max_retries=jitter_max_retries,
1604 logp_fn=model_logp_fn,
1605 )
1607 apoints = [DictToArrayBijection.map(point) for point in initial_points]
1608 apoints_data = [apoint.data for apoint in apoints]
File ~\.local\share\mamba\envs\pymc_dev\Lib\site-packages\pymc\sampling\mcmc.py:1479, in _init_jitter(model, initvals, seeds, jitter, jitter_max_retries, logp_fn)
1476 if not np.isfinite(point_logp):
1477 if i == jitter_max_retries:
1478 # Print informative message on last attempted point
-> 1479 model.check_start_vals(point)
1480 # Retry with a new seed
1481 seed = rng.integers(2**30, dtype=np.int64)
File ~\.local\share\mamba\envs\pymc_dev\Lib\site-packages\pymc\model\core.py:1761, in Model.check_start_vals(self, start, **kwargs)
1758 initial_eval = self.point_logps(point=elem, **kwargs)
1760 if not all(np.isfinite(v) for v in initial_eval.values()):
-> 1761 raise SamplingError(
1762 "Initial evaluation of model at starting point failed!\n"
1763 f"Starting values:\n{elem}\n\n"
1764 f"Logp initial evaluation results:\n{initial_eval}\n"
1765 "You can call `model.debug()` for more details."
1766 )
SamplingError: Initial evaluation of model at starting point failed!
Starting values:
{'ell_log__': array(0.06154412), 'eta_log__': array(-0.07390129), 'f_hsgp_coeffs': array([ 0.8964528 , 0.49850679, 0.77766117, -0.21469201, -0.83229329,
-0.2429649 , -0.1551994 , -0.16094197, -0.66134098, 0.08932747,
0.85003664, -0.96698846, -0.24095418, 0.95928666, 0.40584951,
0.65621531, 0.34217628, -0.13498463, 0.6735691 , -0.91265536,
-0.05148727, 0.37550022, -0.66848646, -0.237498 , 0.13401424,
0.07616318, -0.76062642, 0.42537257, -0.78619399, 0.71017075,
0.62232671, 0.76061207, -0.25878416, -0.71957469, 0.75449224,
-0.2458921 , -0.64380881, -0.88398595, -0.73363227, -0.72695346,
-0.26828684, 0.64891776, -0.68961931, 0.72908515, -0.42343627,
0.24523088, -0.50362676, -0.80204453, -0.47411123, 0.06919655,
-0.85278136, -0.6872726 , -0.7074239 , -0.97904535, -0.2096503 ,
0.41902197, 0.25750279, 0.16304053, -0.37161017, -0.36869419,
0.87463671, -0.99804548, -0.2472362 , -0.99437107, 0.17233818,
-0.53704303, -0.70933562, -0.6216585 , -0.74211035, -0.11780913,
-0.33046545, -0.10765366, 0.09696944, -0.68235125, -0.78363395,
-0.53045928, -0.17417613, 0.691059 , -0.05228136, -0.38724882,
0.35066208, -0.5149922 , -0.77655213, 0.45167872, 0.96291537,
-0.74180878, 0.47324007, 0.07420529, 0.45694168, -0.19554454,
-0.08631478, 0.40328765, -0.82952522, 0.33224662, 0.06260759,
-0.54895729, 0.75930369, -0.3085233 , -0.79609509, 0.82898824,
0.53739623, 0.30328473, -0.90124674, -0.64246727, -0.21607528,
-0.04892372, 0.90662235, 0.71510085, -0.22509855, -0.26623875,
-0.641338 , -0.75124308, -0.8214267 , 0.5451419 , 0.02570617,
0.4018908 , -0.1126687 , -0.31593296, 0.0362656 , 0.76238948,
0.3919529 , -0.27760741, -0.10068226, -0.04583653, 0.74203014,
0.75065354, 0.54871431, -0.64430454, 0.53359048, -0.97495406,
-0.73663779, 0.33514719, 0.69741655, -0.53137909, 0.78693164,
0.17234047, 0.74777694, -0.1744733 , 0.7607344 , -0.86257238,
-0.17365085, -0.82280093, 0.7484344 , 0.88597422, -0.91898113,
-0.77001598, -0.95169786, -0.38264231, -0.81062648, -0.47486829,
0.64959473, 0.24822373, -0.02866922, 0.4009877 , 0.95769761,
0.49782551, -0.99525773, -0.53563913, 0.77753391, -0.09806273,
-0.20661118, 0.65396797, -0.97223146, 0.96095015, -0.66162998,
0.07860522, -0.2993302 , 0.03866622, 0.18522487, 0.54755709,
-0.08540239, 0.18920188, 0.55557113, 0.26725513, 0.54615184,
-0.18610811, -0.38184134, -0.4754358 , -0.54684329, -0.50546527,
0.85735608, -0.89882684, 0.16445413, -0.86161286, -0.13556939,
-0.8191506 , 0.95764276, 0.55347055, -0.23243567, -0.35112965,
0.34498401, 0.93611154, 0.51646511, -0.95735037, 0.28280915,
0.18386817, -0.54182348, 0.48228643, -0.53230718, -0.82028976]), 'sigma_log__': array(0.69110407)}
Logp initial evaluation results:
{'ell': np.float64(-1.0), 'eta': np.float64(-1.0), 'f_hsgp_coeffs': np.float64(-inf), 'sigma': np.float64(-1.3), 'y_obs': np.float64(-534.46)}
You can call `model.debug()` for more details.
PyMC version information:
python 3.13.5
pymc 5.25.1
pytensor 2.31.7
numpy 2.2.6
Context for the issue:
RBF (ExpQuad) is the textbook covariance function and a go-to choice when exploring how to model data with GPs.