Skip to content

Commit 9024c2b

Browse files
authored
Fix JAX sampling funcs overwriting existing var's dims and coords (#6041)
* chore: tidy up existing test for idata kwargs * test: add to test to cause failure * feat: update coords and dims extracted from model with idata_kwargs in JAX samplers * doc: update docstring of JAX samplers with behavior for coords and dims * fix: remove pytest mark used in dev * style: fix typehints for older Python vers * style: minor styling changes per code review' - add details to comments about changing idata_kwargs - change test to comparing lists instead of sets - remove an explicit 'return None'
1 parent a229983 commit 9024c2b

File tree

2 files changed

+52
-13
lines changed

2 files changed

+52
-13
lines changed

pymc/sampling_jax.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,16 @@ def _get_batched_jittered_initial_points(
172172
return initial_points
173173

174174

175+
def _update_coords_and_dims(
176+
coords: Dict[str, Any], dims: Dict[str, Any], idata_kwargs: Dict[str, Any]
177+
) -> None:
178+
"""Update 'coords' and 'dims' dicts with values in 'idata_kwargs'."""
179+
if "coords" in idata_kwargs:
180+
coords.update(idata_kwargs.pop("coords"))
181+
if "dims" in idata_kwargs:
182+
dims.update(idata_kwargs.pop("dims"))
183+
184+
175185
@partial(jax.jit, static_argnums=(2, 3, 4, 5, 6))
176186
def _blackjax_inference_loop(
177187
seed,
@@ -264,7 +274,8 @@ def sample_blackjax_nuts(
264274
value for the ``log_likelihood`` key to indicate that the pointwise log
265275
likelihood should not be included in the returned object. Values for
266276
``observed_data``, ``constant_data``, ``coords``, and ``dims`` are inferred from
267-
the ``model`` argument if not provided in ``idata_kwargs``.
277+
the ``model`` argument if not provided in ``idata_kwargs``. If ``coords`` and
278+
``dims`` are provided, they are used to update the inferred dictionaries.
268279
269280
Returns
270281
-------
@@ -376,6 +387,9 @@ def sample_blackjax_nuts(
376387
}
377388

378389
posterior = mcmc_samples
390+
# Update 'coords' and 'dims' extracted from the model with user 'idata_kwargs'
391+
# and drop keys 'coords' and 'dims' from 'idata_kwargs' if present.
392+
_update_coords_and_dims(coords=coords, dims=dims, idata_kwargs=idata_kwargs)
379393
# Use 'partial' to set default arguments before passing 'idata_kwargs'
380394
to_trace = partial(
381395
az.from_dict,
@@ -470,7 +484,8 @@ def sample_numpyro_nuts(
470484
value for the ``log_likelihood`` key to indicate that the pointwise log
471485
likelihood should not be included in the returned object. Values for
472486
``observed_data``, ``constant_data``, ``coords``, and ``dims`` are inferred from
473-
the ``model`` argument if not provided in ``idata_kwargs``.
487+
the ``model`` argument if not provided in ``idata_kwargs``. If ``coords`` and
488+
``dims`` are provided, they are used to update the inferred dictionaries.
474489
nuts_kwargs: dict, optional
475490
Keyword arguments for :func:`numpyro.infer.NUTS`.
476491
@@ -596,6 +611,9 @@ def sample_numpyro_nuts(
596611
}
597612

598613
posterior = mcmc_samples
614+
# Update 'coords' and 'dims' extracted from the model with user 'idata_kwargs'
615+
# and drop keys 'coords' and 'dims' from 'idata_kwargs' if present.
616+
_update_coords_and_dims(coords=coords, dims=dims, idata_kwargs=idata_kwargs)
599617
# Use 'partial' to set default arguments before passing 'idata_kwargs'
600618
to_trace = partial(
601619
az.from_dict,
@@ -608,5 +626,4 @@ def sample_numpyro_nuts(
608626
attrs=make_attrs(attrs, library=numpyro),
609627
)
610628
az_trace = to_trace(posterior=posterior, **idata_kwargs)
611-
612629
return az_trace

pymc/tests/test_sampling_jax.py

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
from typing import Any, Dict
1+
from typing import Any, Callable, Dict, Optional
22
from unittest import mock
33

44
import aesara
55
import aesara.tensor as at
6+
import arviz as az
67
import jax
78
import numpy as np
89
import pytest
@@ -159,11 +160,14 @@ def test_get_jaxified_logp():
159160
assert not np.isinf(jax_fn((np.array(5000.0), np.array(5000.0))))
160161

161162

162-
@pytest.fixture
163-
def model_test_idata_kwargs(scope="module"):
164-
with pm.Model(coords={"x_coord": ["a", "b"], "x_coord2": [1, 2]}) as m:
163+
@pytest.fixture(scope="module")
164+
def model_test_idata_kwargs() -> pm.Model:
165+
with pm.Model(
166+
coords={"x_coord": ["a", "b"], "x_coord2": [1, 2], "z_coord": ["apple", "banana", "orange"]}
167+
) as m:
165168
x = pm.Normal("x", shape=(2,), dims=["x_coord"])
166-
y = pm.Normal("y", x, observed=[0, 0])
169+
_ = pm.Normal("y", x, observed=[0, 0])
170+
_ = pm.Normal("z", 0, 1, dims="z_coord")
167171
pm.ConstantData("constantdata", [1, 2, 3])
168172
pm.MutableData("mutabledata", 2)
169173
return m
@@ -190,7 +194,13 @@ def model_test_idata_kwargs(scope="module"):
190194
],
191195
)
192196
@pytest.mark.parametrize("postprocessing_backend", [None, "cpu"])
193-
def test_idata_kwargs(model_test_idata_kwargs, sampler, idata_kwargs, postprocessing_backend):
197+
def test_idata_kwargs(
198+
model_test_idata_kwargs: pm.Model,
199+
sampler: Callable[..., az.InferenceData],
200+
idata_kwargs: Dict[str, Any],
201+
postprocessing_backend: Optional[str],
202+
):
203+
idata: Optional[az.InferenceData] = None
194204
with model_test_idata_kwargs:
195205
idata = sampler(
196206
tune=50,
@@ -199,19 +209,31 @@ def test_idata_kwargs(model_test_idata_kwargs, sampler, idata_kwargs, postproces
199209
idata_kwargs=idata_kwargs,
200210
postprocessing_backend=postprocessing_backend,
201211
)
202-
assert "constantdata" in idata.constant_data
203-
assert "mutabledata" in idata.constant_data
212+
assert idata is not None
213+
const_data = idata.get("constant_data")
214+
assert const_data is not None
215+
assert "constantdata" in const_data
216+
assert "mutabledata" in const_data
204217

205218
if idata_kwargs.get("log_likelihood", True):
206219
assert "log_likelihood" in idata
207220
else:
208221
assert "log_likelihood" not in idata
209222

223+
posterior = idata.get("posterior")
224+
assert posterior is not None
210225
x_dim_expected = idata_kwargs.get("dims", model_test_idata_kwargs.RV_dims)["x"][0]
211-
assert idata.posterior.x.dims[-1] == x_dim_expected
226+
assert x_dim_expected is not None
227+
assert posterior["x"].dims[-1] == x_dim_expected
212228

213229
x_coords_expected = idata_kwargs.get("coords", model_test_idata_kwargs.coords)[x_dim_expected]
214-
assert list(x_coords_expected) == list(idata.posterior.x.coords[x_dim_expected].values)
230+
assert x_coords_expected is not None
231+
assert list(x_coords_expected) == list(posterior["x"].coords[x_dim_expected].values)
232+
233+
assert posterior["z"].dims[2] == "z_coord"
234+
assert np.all(
235+
posterior["z"].coords["z_coord"].values == np.array(["apple", "banana", "orange"])
236+
)
215237

216238

217239
def test_get_batched_jittered_initial_points():

0 commit comments

Comments
 (0)