From 9ad9cf7892b8197277bd53f3f324b4d9d7a6f834 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Wed, 26 Feb 2025 13:17:59 +0100 Subject: [PATCH] Fix bug when reusing jax logp for initial point generation --- pymc/sampling/jax.py | 5 ++++- tests/sampling/test_jax.py | 21 +++++++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/pymc/sampling/jax.py b/pymc/sampling/jax.py index b2cbff9b6..7ae32e253 100644 --- a/pymc/sampling/jax.py +++ b/pymc/sampling/jax.py @@ -240,7 +240,10 @@ def eval_logp_initial_point(point: dict[str, np.ndarray]) -> jax.Array: Wraps jaxified logp function to accept a dict of {model_variable: np.array} key:value pairs. """ - return logp_fn(point.values()) + # Because logp_fn is not jitted, we need to convert inputs to jax arrays, + # or some methods that are only available for jax arrays will fail + # such as x.at[indices].set(y) + return logp_fn([jax.numpy.asarray(v) for v in point.values()]) initial_points = _init_jitter( model, diff --git a/tests/sampling/test_jax.py b/tests/sampling/test_jax.py index ddec60e53..68092d9b9 100644 --- a/tests/sampling/test_jax.py +++ b/tests/sampling/test_jax.py @@ -352,6 +352,27 @@ def test_get_batched_jittered_initial_points(): assert np.all(ips[0][0] != ips[0][1]) +def test_get_batched_jittered_initial_points_set_subtensor(): + """Regression bug for issue described in + https://discourse.pymc.io/t/attributeerror-numpy-ndarray-object-has-no-attribute-at-when-sampling-lkj-cholesky-covariance-priors-for-multivariate-normal-models-example-with-numpyro-or-blackjax/16598/3 + + Which was caused by passing numpy arrays to a non-jitted logp function + """ + with pm.Model() as model: + # Set operation will use `x.at[1].set(100)` which is only available in JAX + x = pm.Normal("x", mu=[-100, -100]) + mu_y = x[1].set(100) + y = pm.Normal("y", mu=mu_y) + + logp_fn = get_jaxified_logp(model) + [x_ips, y_ips] = _get_batched_jittered_initial_points( + model, chains=3, initvals=None, logp_fn=logp_fn, jitter=True, random_seed=0 + ) + assert np.all(x_ips < -10) + assert np.all(y_ips[..., 0] < -10) + assert np.all(y_ips[..., 1] > 10) + + @pytest.mark.parametrize( "sampler", [