Skip to content

Commit 14870b7

Browse files
committed
[DO NOT MERGE] Test statespace models in numba backend
1 parent 04b838a commit 14870b7

File tree

2 files changed

+11
-10
lines changed

2 files changed

+11
-10
lines changed

tests/statespace/test_coord_assignment.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,6 @@ def make_model(index):
137137
with pytest.warns(UserWarning, match="No time index found on the supplied data"):
138138
ss_mod.build_statespace_graph(
139139
a["A"],
140-
mode="JAX",
141140
)
142141
return model
143142

tests/statespace/test_statespace_JAX.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@
2222
)
2323
from tests.statespace.utilities.test_helpers import load_nile_test_data
2424

25-
pytest.importorskip("jax")
26-
pytest.importorskip("numpyro")
25+
pytest.importorskip("numba")
26+
# pytest.importorskip("numpyro")
2727

2828

2929
floatX = pytensor.config.floatX
@@ -38,7 +38,7 @@ def pymc_mod(ss_mod):
3838
zeta = pm.Deterministic("zeta", 1 - rho)
3939

4040
ss_mod.build_statespace_graph(
41-
data=nile, mode="JAX", save_kalman_filter_outputs_in_idata=True
41+
data=nile, mode="NUMBA", save_kalman_filter_outputs_in_idata=True
4242
)
4343
names = ["x0", "P0", "c", "d", "T", "Z", "R", "H", "Q"]
4444
for name, matrix in zip(names, ss_mod.unpack_statespace()):
@@ -62,7 +62,7 @@ def exog_pymc_mod(exog_ss_mod, rng):
6262
beta_exog = pm.Normal("beta_exog", dims=["exog_state"])
6363

6464
sigma_trend = pm.Exponential("sigma_trend", 1, dims=["trend_shock"])
65-
exog_ss_mod.build_statespace_graph(y, mode="JAX")
65+
exog_ss_mod.build_statespace_graph(y, mode="NUMBA")
6666

6767
return m
6868

@@ -77,12 +77,13 @@ def idata(pymc_mod, rng):
7777
tune=1,
7878
chains=1,
7979
random_seed=rng,
80-
nuts_sampler="numpyro",
80+
nuts_sampler="pymc",
81+
compile_kwargs={"mode": "NUMBA"},
8182
progressbar=False,
8283
)
8384
with freeze_dims_and_data(pymc_mod):
8485
idata_prior = pm.sample_prior_predictive(
85-
samples=10, random_seed=rng, compile_kwargs={"mode": "JAX"}
86+
samples=10, random_seed=rng, compile_kwargs={"mode": "NUMBA"}
8687
)
8788

8889
idata.extend(idata_prior)
@@ -100,12 +101,13 @@ def idata_exog(exog_pymc_mod, rng):
100101
tune=1,
101102
chains=1,
102103
random_seed=rng,
103-
nuts_sampler="numpyro",
104+
nuts_sampler="pymc",
105+
compile_kwargs={"mode": "NUMBA"},
104106
progressbar=False,
105107
)
106108
with freeze_dims_and_data(pymc_mod):
107109
idata_prior = pm.sample_prior_predictive(
108-
samples=10, random_seed=rng, compile_kwargs={"mode": "JAX"}
110+
samples=10, random_seed=rng, compile_kwargs={"mode": "NUMBA"}
109111
)
110112

111113
idata.extend(idata_prior)
@@ -121,7 +123,7 @@ def test_no_nans_in_sampling_output(ss_mod, group, matrix, idata):
121123
@pytest.mark.parametrize("group", ["prior", "posterior"])
122124
@pytest.mark.parametrize("kind", ["conditional", "unconditional"])
123125
def test_sampling_methods(group, kind, ss_mod, idata, rng):
124-
assert ss_mod._fit_mode == "JAX"
126+
assert ss_mod._fit_mode == "NUMBA"
125127

126128
f = getattr(ss_mod, f"sample_{kind}_{group}")
127129
with pytest.warns(UserWarning, match="The RandomType SharedVariables"):

0 commit comments

Comments
 (0)