Skip to content

Commit 321e391

Browse files
committed
Draft external sampler API
1 parent 011fb35 commit 321e391

File tree

6 files changed

+345
-254
lines changed

6 files changed

+345
-254
lines changed

pymc/sampling/jax.py

Lines changed: 5 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
from pytensor.graph.fg import FunctionGraph
3636
from pytensor.graph.replace import clone_replace
3737
from pytensor.link.jax.dispatch import jax_funcify
38-
from pytensor.raise_op import Assert
3938
from pytensor.tensor import TensorVariable
4039
from pytensor.tensor.random.type import RandomType
4140

@@ -47,7 +46,6 @@
4746
)
4847
from pymc.distributions.multivariate import PosDefMatrix
4948
from pymc.initial_point import StartDict
50-
from pymc.logprob.utils import CheckParameterValue
5149
from pymc.sampling.mcmc import _init_jitter
5250
from pymc.stats.convergence import log_warnings, run_convergence_checks
5351
from pymc.util import (
@@ -71,19 +69,6 @@
7169
)
7270

7371

74-
@jax_funcify.register(Assert)
75-
@jax_funcify.register(CheckParameterValue)
76-
def jax_funcify_Assert(op, **kwargs):
77-
# Jax does not allow assert whose values aren't known during JIT compilation
78-
# within it's JIT-ed code. Hence we need to make a simple pass through
79-
# version of the Assert Op.
80-
# https://github.com/google/jax/issues/2273#issuecomment-589098722
81-
def assert_fn(value, *inps):
82-
return value
83-
84-
return assert_fn
85-
86-
8772
@jax_funcify.register(PosDefMatrix)
8873
def jax_funcify_PosDefMatrix(op, **kwargs):
8974
def posdefmatrix_fn(value, *inps):
@@ -520,8 +505,6 @@ def sample_jax_nuts(
520505
keep_untransformed: bool = False,
521506
chain_method: Literal["parallel", "vectorized"] = "parallel",
522507
postprocessing_backend: Literal["cpu", "gpu"] | None = None,
523-
postprocessing_vectorize: Literal["vmap", "scan"] | None = None,
524-
postprocessing_chunks=None,
525508
idata_kwargs: dict | None = None,
526509
compute_convergence_checks: bool = True,
527510
nuts_sampler: Literal["numpyro", "blackjax"],
@@ -593,25 +576,6 @@ def sample_jax_nuts(
593576
with their respective sample stats and pointwise log likeihood values (unless
594577
skipped with ``idata_kwargs``).
595578
"""
596-
if postprocessing_chunks is not None:
597-
import warnings
598-
599-
warnings.warn(
600-
"postprocessing_chunks is deprecated due to being unstable, "
601-
"using postprocessing_vectorize='scan' instead",
602-
DeprecationWarning,
603-
)
604-
605-
if postprocessing_vectorize is not None:
606-
import warnings
607-
608-
warnings.warn(
609-
'postprocessing_vectorize={"scan", "vmap"} will be removed in a future release.',
610-
FutureWarning,
611-
)
612-
else:
613-
postprocessing_vectorize = "vmap"
614-
615579
model = modelcontext(model)
616580

617581
if var_names is not None:
@@ -674,7 +638,6 @@ def sample_jax_nuts(
674638
model,
675639
raw_mcmc_samples,
676640
backend=postprocessing_backend,
677-
postprocessing_vectorize=postprocessing_vectorize,
678641
)
679642
else:
680643
log_likelihood = None
@@ -684,7 +647,6 @@ def sample_jax_nuts(
684647
jax_fn,
685648
raw_mcmc_samples,
686649
postprocessing_backend=postprocessing_backend,
687-
postprocessing_vectorize=postprocessing_vectorize,
688650
donate_samples=True,
689651
)
690652
del raw_mcmc_samples
@@ -704,8 +666,8 @@ def sample_jax_nuts(
704666
dims.update(idata_kwargs.pop("dims"))
705667

706668
# Use 'partial' to set default arguments before passing 'idata_kwargs'
707-
to_trace = partial(
708-
az.from_dict,
669+
idata = az.from_dict(
670+
posterior=mcmc_samples,
709671
log_likelihood=log_likelihood,
710672
observed_data=find_observations(model),
711673
constant_data=find_constants(model),
@@ -714,14 +676,13 @@ def sample_jax_nuts(
714676
dims=dims,
715677
attrs=make_attrs(attrs, library=library),
716678
posterior_attrs=make_attrs(attrs, library=library),
679+
**idata_kwargs,
717680
)
718-
az_trace = to_trace(posterior=mcmc_samples, **idata_kwargs)
719681

720682
if compute_convergence_checks:
721-
warns = run_convergence_checks(az_trace, model)
722-
log_warnings(warns)
683+
log_warnings(run_convergence_checks(idata, model))
723684

724-
return az_trace
685+
return idata
725686

726687

727688
sample_numpyro_nuts = partial(sample_jax_nuts, nuts_sampler="numpyro")

0 commit comments

Comments
 (0)