35
35
from pytensor .graph .fg import FunctionGraph
36
36
from pytensor .graph .replace import clone_replace
37
37
from pytensor .link .jax .dispatch import jax_funcify
38
- from pytensor .raise_op import Assert
39
38
from pytensor .tensor import TensorVariable
40
39
from pytensor .tensor .random .type import RandomType
41
40
47
46
)
48
47
from pymc .distributions .multivariate import PosDefMatrix
49
48
from pymc .initial_point import StartDict
50
- from pymc .logprob .utils import CheckParameterValue
51
49
from pymc .sampling .mcmc import _init_jitter
52
50
from pymc .stats .convergence import log_warnings , run_convergence_checks
53
51
from pymc .util import (
71
69
)
72
70
73
71
74
- @jax_funcify .register (Assert )
75
- @jax_funcify .register (CheckParameterValue )
76
- def jax_funcify_Assert (op , ** kwargs ):
77
- # Jax does not allow assert whose values aren't known during JIT compilation
78
- # within it's JIT-ed code. Hence we need to make a simple pass through
79
- # version of the Assert Op.
80
- # https://github.com/google/jax/issues/2273#issuecomment-589098722
81
- def assert_fn (value , * inps ):
82
- return value
83
-
84
- return assert_fn
85
-
86
-
87
72
@jax_funcify .register (PosDefMatrix )
88
73
def jax_funcify_PosDefMatrix (op , ** kwargs ):
89
74
def posdefmatrix_fn (value , * inps ):
@@ -520,8 +505,6 @@ def sample_jax_nuts(
520
505
keep_untransformed : bool = False ,
521
506
chain_method : Literal ["parallel" , "vectorized" ] = "parallel" ,
522
507
postprocessing_backend : Literal ["cpu" , "gpu" ] | None = None ,
523
- postprocessing_vectorize : Literal ["vmap" , "scan" ] | None = None ,
524
- postprocessing_chunks = None ,
525
508
idata_kwargs : dict | None = None ,
526
509
compute_convergence_checks : bool = True ,
527
510
nuts_sampler : Literal ["numpyro" , "blackjax" ],
@@ -593,25 +576,6 @@ def sample_jax_nuts(
593
576
with their respective sample stats and pointwise log likeihood values (unless
594
577
skipped with ``idata_kwargs``).
595
578
"""
596
- if postprocessing_chunks is not None :
597
- import warnings
598
-
599
- warnings .warn (
600
- "postprocessing_chunks is deprecated due to being unstable, "
601
- "using postprocessing_vectorize='scan' instead" ,
602
- DeprecationWarning ,
603
- )
604
-
605
- if postprocessing_vectorize is not None :
606
- import warnings
607
-
608
- warnings .warn (
609
- 'postprocessing_vectorize={"scan", "vmap"} will be removed in a future release.' ,
610
- FutureWarning ,
611
- )
612
- else :
613
- postprocessing_vectorize = "vmap"
614
-
615
579
model = modelcontext (model )
616
580
617
581
if var_names is not None :
@@ -674,7 +638,6 @@ def sample_jax_nuts(
674
638
model ,
675
639
raw_mcmc_samples ,
676
640
backend = postprocessing_backend ,
677
- postprocessing_vectorize = postprocessing_vectorize ,
678
641
)
679
642
else :
680
643
log_likelihood = None
@@ -684,7 +647,6 @@ def sample_jax_nuts(
684
647
jax_fn ,
685
648
raw_mcmc_samples ,
686
649
postprocessing_backend = postprocessing_backend ,
687
- postprocessing_vectorize = postprocessing_vectorize ,
688
650
donate_samples = True ,
689
651
)
690
652
del raw_mcmc_samples
@@ -704,8 +666,8 @@ def sample_jax_nuts(
704
666
dims .update (idata_kwargs .pop ("dims" ))
705
667
706
668
# Use 'partial' to set default arguments before passing 'idata_kwargs'
707
- to_trace = partial (
708
- az . from_dict ,
669
+ idata = az . from_dict (
670
+ posterior = mcmc_samples ,
709
671
log_likelihood = log_likelihood ,
710
672
observed_data = find_observations (model ),
711
673
constant_data = find_constants (model ),
@@ -714,14 +676,13 @@ def sample_jax_nuts(
714
676
dims = dims ,
715
677
attrs = make_attrs (attrs , library = library ),
716
678
posterior_attrs = make_attrs (attrs , library = library ),
679
+ ** idata_kwargs ,
717
680
)
718
- az_trace = to_trace (posterior = mcmc_samples , ** idata_kwargs )
719
681
720
682
if compute_convergence_checks :
721
- warns = run_convergence_checks (az_trace , model )
722
- log_warnings (warns )
683
+ log_warnings (run_convergence_checks (idata , model ))
723
684
724
- return az_trace
685
+ return idata
725
686
726
687
727
688
sample_numpyro_nuts = partial (sample_jax_nuts , nuts_sampler = "numpyro" )
0 commit comments