2222)
2323from tests .statespace .utilities .test_helpers import load_nile_test_data
2424
25- pytest .importorskip ("jax " )
26- pytest .importorskip ("numpyro" )
25+ pytest .importorskip ("numba " )
26+ # pytest.importorskip("numpyro")
2727
2828
2929floatX = pytensor .config .floatX
@@ -38,7 +38,7 @@ def pymc_mod(ss_mod):
3838 zeta = pm .Deterministic ("zeta" , 1 - rho )
3939
4040 ss_mod .build_statespace_graph (
41- data = nile , mode = "JAX " , save_kalman_filter_outputs_in_idata = True
41+ data = nile , mode = "NUMBA " , save_kalman_filter_outputs_in_idata = True
4242 )
4343 names = ["x0" , "P0" , "c" , "d" , "T" , "Z" , "R" , "H" , "Q" ]
4444 for name , matrix in zip (names , ss_mod .unpack_statespace ()):
@@ -62,7 +62,7 @@ def exog_pymc_mod(exog_ss_mod, rng):
6262 beta_exog = pm .Normal ("beta_exog" , dims = ["exog_state" ])
6363
6464 sigma_trend = pm .Exponential ("sigma_trend" , 1 , dims = ["trend_shock" ])
65- exog_ss_mod .build_statespace_graph (y , mode = "JAX " )
65+ exog_ss_mod .build_statespace_graph (y , mode = "NUMBA " )
6666
6767 return m
6868
@@ -77,12 +77,13 @@ def idata(pymc_mod, rng):
7777 tune = 1 ,
7878 chains = 1 ,
7979 random_seed = rng ,
80- nuts_sampler = "numpyro" ,
80+ nuts_sampler = "pymc" ,
81+ compile_kwargs = {"mode" : "NUMBA" },
8182 progressbar = False ,
8283 )
8384 with freeze_dims_and_data (pymc_mod ):
8485 idata_prior = pm .sample_prior_predictive (
85- samples = 10 , random_seed = rng , compile_kwargs = {"mode" : "JAX " }
86+ samples = 10 , random_seed = rng , compile_kwargs = {"mode" : "NUMBA " }
8687 )
8788
8889 idata .extend (idata_prior )
@@ -100,12 +101,13 @@ def idata_exog(exog_pymc_mod, rng):
100101 tune = 1 ,
101102 chains = 1 ,
102103 random_seed = rng ,
103- nuts_sampler = "numpyro" ,
104+ nuts_sampler = "pymc" ,
105+ compile_kwargs = {"mode" : "NUMBA" },
104106 progressbar = False ,
105107 )
106108 with freeze_dims_and_data (pymc_mod ):
107109 idata_prior = pm .sample_prior_predictive (
108- samples = 10 , random_seed = rng , compile_kwargs = {"mode" : "JAX " }
110+ samples = 10 , random_seed = rng , compile_kwargs = {"mode" : "NUMBA " }
109111 )
110112
111113 idata .extend (idata_prior )
@@ -121,7 +123,7 @@ def test_no_nans_in_sampling_output(ss_mod, group, matrix, idata):
121123@pytest .mark .parametrize ("group" , ["prior" , "posterior" ])
122124@pytest .mark .parametrize ("kind" , ["conditional" , "unconditional" ])
123125def test_sampling_methods (group , kind , ss_mod , idata , rng ):
124- assert ss_mod ._fit_mode == "JAX "
126+ assert ss_mod ._fit_mode == "NUMBA "
125127
126128 f = getattr (ss_mod , f"sample_{ kind } _{ group } " )
127129 with pytest .warns (UserWarning , match = "The RandomType SharedVariables" ):
0 commit comments