File tree Expand file tree Collapse file tree 2 files changed +10
-3
lines changed
Expand file tree Collapse file tree 2 files changed +10
-3
lines changed Original file line number Diff line number Diff line change 6868 "sample_numpyro_nuts" ,
6969)
7070
71+ JaxNutsSampler = Literal ["numpyro" , "blackjax" ]
72+
7173
7274@jax_funcify .register (Assert )
7375@jax_funcify .register (CheckParameterValue )
@@ -486,7 +488,7 @@ def sample_jax_nuts(
486488 postprocessing_chunks = None ,
487489 idata_kwargs : dict | None = None ,
488490 compute_convergence_checks : bool = True ,
489- nuts_sampler : Literal [ "numpyro" , "blackjax" ] ,
491+ nuts_sampler : JaxNutsSampler ,
490492) -> az .InferenceData :
491493 """
492494 Draw samples from the posterior using a jax NUTS method.
Original file line number Diff line number Diff line change 8282
8383Step : TypeAlias = BlockedStep | CompoundStep
8484
85- ExternalNutsSampler = ["nutpie" , "numpyro" , "blackjax" ]
85+ ExternalNutsSampler = Literal ["nutpie" , "numpyro" , "blackjax" ]
8686NutsSampler = Literal ["pymc" ] | ExternalNutsSampler
87-
8887NutpieBackend = Literal ["numba" , "jax" ]
88+
89+
8990NUTPIE_BACKENDS = get_args (NutpieBackend )
9091NUTPIE_DEFAULT_BACKEND = cast (NutpieBackend , "numba" )
9192
@@ -381,6 +382,10 @@ def extract_backend(string: str) -> NutpieBackend:
381382 elif sampler in ("numpyro" , "blackjax" ):
382383 import pymc .sampling .jax as pymc_jax
383384
385+ from pymc .sampling .jax import JaxNutsSampler
386+
387+ sampler = cast (JaxNutsSampler , sampler )
388+
384389 idata = pymc_jax .sample_jax_nuts (
385390 draws = draws ,
386391 tune = tune ,
You can’t perform that action at this time.
0 commit comments