Skip to content

Commit bb5a866

Browse files
committed
fix(numba): non-contiguous shared variable
Shared variables (ie pm.Data) that were non-contiguous could lead to incorrect results in the pymc numba backend. We now ensure that they are always c-contiguous by copying if they are not.
1 parent 00e6ced commit bb5a866

File tree

2 files changed

+34
-2
lines changed

2 files changed

+34
-2
lines changed

python/nutpie/compile_pymc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def with_data(self, **updates):
131131
if name not in shared_data:
132132
raise KeyError(f"Unknown shared variable: {name}")
133133
old_val = shared_data[name]
134-
new_val = np.asarray(new_val, dtype=old_val.dtype).copy()
134+
new_val = np.array(new_val, dtype=old_val.dtype, order="C", copy=True)
135135
new_val.flags.writeable = False
136136
if old_val.ndim != new_val.ndim:
137137
raise ValueError(
@@ -256,7 +256,7 @@ def _compile_pymc_model_numba(
256256
for val in [*logp_fn_pt.get_shared(), *expand_fn_pt.get_shared()]:
257257
if val.name in shared_data and val not in seen:
258258
raise ValueError(f"Shared variables must have unique names: {val.name}")
259-
shared_data[val.name] = val.get_value()
259+
shared_data[val.name] = np.array(val.get_value(), order="C", copy=True)
260260
shared_vars[val.name] = val
261261
seen.add(val)
262262

tests/test_pymc.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,38 @@ def test_pymc_model(backend, gradient_backend):
3131
trace.posterior.a # noqa: B018
3232

3333

34+
@pytest.mark.pymc
35+
def test_order_shared():
36+
a_val = np.array([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])
37+
with pm.Model() as model:
38+
a = pm.Data("a", np.copy(a_val, order="C"))
39+
b = pm.Normal("b", shape=(2, 5))
40+
pm.Deterministic("c", (a[:, None, :] * b[:, :, None]).sum(-1))
41+
42+
compiled = nutpie.compile_pymc_model(model, backend="numba")
43+
trace = nutpie.sample(compiled)
44+
np.testing.assert_allclose(
45+
(
46+
trace.posterior.b.values[:, :, :, :, None] * a_val[None, None, :, None, :]
47+
).sum(-1),
48+
trace.posterior.c.values,
49+
)
50+
51+
with pm.Model() as model:
52+
a = pm.Data("a", np.copy(a_val, order="F"))
53+
b = pm.Normal("b", shape=(2, 5))
54+
pm.Deterministic("c", (a[:, None, :] * b[:, :, None]).sum(-1))
55+
56+
compiled = nutpie.compile_pymc_model(model, backend="numba")
57+
trace = nutpie.sample(compiled)
58+
np.testing.assert_allclose(
59+
(
60+
trace.posterior.b.values[:, :, :, :, None] * a_val[None, None, :, None, :]
61+
).sum(-1),
62+
trace.posterior.c.values,
63+
)
64+
65+
3466
@pytest.mark.pymc
3567
@parameterize_backends
3668
def test_low_rank(backend, gradient_backend):

0 commit comments

Comments
 (0)