Skip to content

Commit ecc44c0

Browse files
committed
fix(pymc): allow data named x with unfrozen model
fixes #157
1 parent 46aaf8c commit ecc44c0

File tree

2 files changed

+19
-4
lines changed

2 files changed

+19
-4
lines changed

python/nutpie/compile_pymc.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -406,8 +406,8 @@ def logp_fn_jax_grad(x, *shared):
406406
seen.add(val)
407407

408408
def make_logp_func():
409-
def logp(x, **shared):
410-
logp, grad = logp_fn(x, *[shared[name] for name in logp_shared_names])
409+
def logp(_x, **shared):
410+
logp, grad = logp_fn(_x, *[shared[name] for name in logp_shared_names])
411411
return float(logp), np.asarray(grad, dtype="float64", order="C")
412412

413413
return logp
@@ -418,8 +418,8 @@ def logp(x, **shared):
418418

419419
def make_expand_func(seed1, seed2, chain):
420420
# TODO handle seeds
421-
def expand(x, **shared):
422-
values = expand_fn(x, *[shared[name] for name in expand_shared_names])
421+
def expand(_x, **shared):
422+
values = expand_fn(_x, *[shared[name] for name in expand_shared_names])
423423
return {
424424
name: np.asarray(val, order="C", dtype=dtype).ravel()
425425
for name, val, dtype in zip(names, values, dtypes, strict=True)

tests/test_pymc.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,21 @@ def test_pymc_model(backend, gradient_backend):
3131
trace.posterior.a # noqa: B018
3232

3333

34+
@pytest.mark.pymc
35+
@parameterize_backends
36+
def test_name_x(backend, gradient_backend):
37+
with pm.Model() as model:
38+
x = pm.Data("x", 1.0)
39+
a = pm.Normal("a", mu=x)
40+
pm.Deterministic("z", x * a)
41+
42+
compiled = nutpie.compile_pymc_model(
43+
model, backend=backend, gradient_backend=gradient_backend, freeze_model=False
44+
)
45+
trace = nutpie.sample(compiled, chains=1)
46+
trace.posterior.a # noqa: B018
47+
48+
3449
@pytest.mark.pymc
3550
def test_order_shared():
3651
a_val = np.array([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])

0 commit comments

Comments
 (0)