Skip to content

Commit 836bd99

Browse files
committed
fix: Fix jax backend with non-identifier variable names
1 parent 1cc8a94 commit 836bd99

File tree

2 files changed

+22
-10
lines changed

2 files changed

+22
-10
lines changed

python/nutpie/compile_pymc.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -297,8 +297,8 @@ def _compile_pymc_model_jax(model, *, gradient_backend=None, **kwargs):
297297
orig_logp_fn = logp_fn._fun
298298

299299
@jax.jit
300-
def logp_fn_jax_grad(x, **shared):
301-
return jax.value_and_grad(lambda x: orig_logp_fn(x, **shared)[0])(x)
300+
def logp_fn_jax_grad(x, *shared):
301+
return jax.value_and_grad(lambda x: orig_logp_fn(x, *shared)[0])(x)
302302

303303
logp_fn = logp_fn_jax_grad
304304

@@ -317,9 +317,7 @@ def logp_fn_jax_grad(x, **shared):
317317

318318
def make_logp_func():
319319
def logp(x, **shared):
320-
logp, grad = logp_fn(
321-
x, **{name: shared[name] for name in logp_shared_names}
322-
)
320+
logp, grad = logp_fn(x, *[shared[name] for name in logp_shared_names])
323321
return float(logp), np.asarray(grad, dtype="float64", order="C")
324322

325323
return logp
@@ -330,9 +328,7 @@ def logp(x, **shared):
330328
def make_expand_func(seed1, seed2, chain):
331329
# TODO handle seeds
332330
def expand(x, **shared):
333-
values = expand_fn(
334-
x, **{name: shared[name] for name in expand_shared_names}
335-
)
331+
values = expand_fn(x, *[shared[name] for name in expand_shared_names])
336332
return {
337333
name: np.asarray(val, order="C", dtype=dtype).ravel()
338334
for name, val, dtype in zip(names, values, dtypes, strict=True)

tests/test_pymc.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,18 +127,34 @@ def test_det(backend, gradient_backend):
127127
assert trace.posterior.b.shape[-1] == 2
128128

129129

130+
@parameterize_backends
131+
def test_non_identifier_names(backend, gradient_backend):
132+
with pm.Model() as model:
133+
a = pm.Uniform("a/b", shape=2)
134+
with pm.Model("foo"):
135+
c = pm.Data("c", np.array([2.0, 3.0]))
136+
pm.Deterministic("b", c * a)
137+
138+
compiled = nutpie.compile_pymc_model(
139+
model, backend=backend, gradient_backend=gradient_backend
140+
)
141+
trace = nutpie.sample(compiled, chains=1)
142+
assert trace.posterior["a/b"].shape[-1] == 2
143+
assert trace.posterior["foo::b"].shape[-1] == 2
144+
145+
130146
@parameterize_backends
131147
def test_pymc_model_shared(backend, gradient_backend):
132148
with pm.Model() as model:
133-
mu = pm.Data("mu", 0.1)
149+
mu = pm.Data("mu", -0.1)
134150
sigma = pm.Data("sigma", np.ones(3))
135151
pm.Normal("a", mu=mu, sigma=sigma, shape=3)
136152

137153
compiled = nutpie.compile_pymc_model(
138154
model, backend=backend, gradient_backend=gradient_backend
139155
)
140156
trace = nutpie.sample(compiled, chains=1, seed=1)
141-
np.testing.assert_allclose(trace.posterior.a.mean().values, 0.1, atol=0.05)
157+
np.testing.assert_allclose(trace.posterior.a.mean().values, -0.1, atol=0.05)
142158

143159
compiled2 = compiled.with_data(mu=10.0, sigma=3 * np.ones(3))
144160
trace2 = nutpie.sample(compiled2, chains=1, seed=1)

0 commit comments

Comments
 (0)