Skip to content

Commit ca04890

Browse files
Make mypy happy
1 parent 6239b3d commit ca04890

File tree

2 files changed

+10
-3
lines changed

2 files changed

+10
-3
lines changed

pymc/sampling/jax.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@
6868
"sample_numpyro_nuts",
6969
)
7070

71+
JaxNutsSampler = Literal["numpyro", "blackjax"]
72+
7173

7274
@jax_funcify.register(Assert)
7375
@jax_funcify.register(CheckParameterValue)
@@ -486,7 +488,7 @@ def sample_jax_nuts(
486488
postprocessing_chunks=None,
487489
idata_kwargs: dict | None = None,
488490
compute_convergence_checks: bool = True,
489-
nuts_sampler: Literal["numpyro", "blackjax"],
491+
nuts_sampler: JaxNutsSampler,
490492
) -> az.InferenceData:
491493
"""
492494
Draw samples from the posterior using a jax NUTS method.

pymc/sampling/mcmc.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,10 +82,11 @@
8282

8383
Step: TypeAlias = BlockedStep | CompoundStep
8484

85-
ExternalNutsSampler = ["nutpie", "numpyro", "blackjax"]
85+
ExternalNutsSampler = Literal["nutpie", "numpyro", "blackjax"]
8686
NutsSampler = Literal["pymc"] | ExternalNutsSampler
87-
8887
NutpieBackend = Literal["numba", "jax"]
88+
89+
8990
NUTPIE_BACKENDS = get_args(NutpieBackend)
9091
NUTPIE_DEFAULT_BACKEND = cast(NutpieBackend, "numba")
9192

@@ -381,6 +382,10 @@ def extract_backend(string: str) -> NutpieBackend:
381382
elif sampler in ("numpyro", "blackjax"):
382383
import pymc.sampling.jax as pymc_jax
383384

385+
from pymc.sampling.jax import JaxNutsSampler
386+
387+
sampler = cast(JaxNutsSampler, sampler)
388+
384389
idata = pymc_jax.sample_jax_nuts(
385390
draws=draws,
386391
tune=tune,

0 commit comments

Comments
 (0)