Skip to content

Commit 2bdb27c

Browse files
committed
Various fixes
1 parent 5eb2061 commit 2bdb27c

File tree

10 files changed

+16
-41
lines changed

10 files changed

+16
-41
lines changed

pymc_extras/distributions/histogram_utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616
import numpy as np
1717
import pymc as pm
18-
import pymc_extras as pmx
1918

2019
from numpy.typing import ArrayLike
2120

pymc_extras/model/marginal/distributions.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,6 @@ def reduce_batch_dependent_logps(
9292
as well as transpose the remaining axis of dep1 logp before adding the two element-wise.
9393
9494
"""
95-
from pymc_extras.model.marginal.graph_analysis import get_support_axes
9695

9796
reduced_logps = []
9897
for dependent_op, dependent_logp, dependent_dims_connection in zip(

pymc_extras/model/model_api.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from functools import wraps
22

33
from pymc import Model
4-
import pymc_extras as pmx
54

65

76
def as_model(*model_args, **model_kwargs):

pymc_extras/statespace/models/ETS.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,6 @@
1010
from pymc_extras.statespace.core.statespace import PyMCStateSpace, floatX
1111
from pymc_extras.statespace.models.utilities import make_default_coords
1212
from pymc_extras.statespace.utils.constants import (
13-
JITTER_DEFAULT,
14-
LONG_MATRIX_NAMES,
15-
MISSING_FILL,
16-
SHORT_NAME_TO_LONG,
1713
ALL_STATE_AUX_DIM,
1814
ALL_STATE_DIM,
1915
ETS_SEASONAL_DIM,

pymc_extras/statespace/models/VARMAX.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,7 @@
1818
OBS_STATE_DIM,
1919
SHOCK_AUX_DIM,
2020
SHOCK_DIM,
21-
JITTER_DEFAULT,
22-
LONG_MATRIX_NAMES,
23-
MISSING_FILL,
24-
SHORT_NAME_TO_LONG,
2521
)
26-
import pymc_extras.statespace as pmss
2722

2823
floatX = pytensor.config.floatX
2924

pymc_extras/statespace/models/utilities.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,6 @@
1111
SHOCK_AUX_DIM,
1212
SHOCK_DIM,
1313
VECTOR_VALUED,
14-
JITTER_DEFAULT,
15-
MISSING_FILL,
16-
SHORT_NAME_TO_LONG,
1714
)
1815

1916

tests/model/marginal/test_graph_analysis.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55
from pytensor.tensor.type_other import NoneTypeT
66

77
from pymc_extras.model.marginal.graph_analysis import (
8-
find_conditional_dependent_rvs,
9-
find_conditional_input_rvs,
108
is_conditional_dependent,
119
subgraph_batch_dim_connection,
1210
)

tests/test_blackjax_smc.py

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -20,28 +20,17 @@
2020
from numpy import dtype
2121
from xarray.core.utils import Frozen
2222

23-
jax = pytest.importorskip("jax")
24-
pytest.importorskip("blackjax")
25-
2623
from pymc_extras.inference.smc.sampling import (
27-
SMC_KERNEL_DEFAULTS,
28-
SMC_KERNEL_NAMES,
29-
SMC_KERNEL_PARAMS,
30-
SMC_KERNEL_PARAMS_DEFAULTS,
31-
SMC_KERNEL_PARAMS_NAMES,
32-
SMC_KERNEL_PARAMS_TYPES,
33-
SMC_KERNEL_TYPES,
34-
SMC_KERNEL_VALID_PARAMS,
35-
SMC_KERNEL_VALID_PARAMS_TYPES,
36-
SMC_KERNELS,
37-
SMC_KERNELS_PARAMS,
38-
SMC_KERNELS_PARAMS_DEFAULTS,
39-
SMC_KERNELS_PARAMS_NAMES,
40-
SMC_KERNELS_PARAMS_TYPES,
41-
SMC_KERNELS_VALID_PARAMS,
42-
SMC_KERNELS_VALID_PARAMS_TYPES,
24+
arviz_from_particles,
25+
blackjax_particles_from_pymc_population,
26+
get_jaxified_loglikelihood,
27+
get_jaxified_logprior,
28+
sample_smc_blackjax,
4329
)
4430

31+
jax = pytest.importorskip("jax")
32+
pytest.importorskip("blackjax")
33+
4534

4635
def two_gaussians_model():
4736
n = 4

tests/test_find_map.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
1+
from typing import Literal
2+
13
import numpy as np
24
import pymc as pm
35
import pytensor.tensor as pt
46
import pytest
57

8+
from pymc_extras.find_map import scipy_optimize_funcs_from_loss
69
from pymc_extras.inference.find_map import (
710
find_MAP,
8-
find_MAP_scipy,
9-
find_MAP_pytensor,
1011
)
1112

1213
pytest.importorskip("jax")
@@ -18,6 +19,10 @@ def rng():
1819
return np.random.default_rng(seed)
1920

2021

22+
# Define GradientBackend type alias
23+
GradientBackend = Literal["jax", "pytensor"]
24+
25+
2126
@pytest.mark.parametrize("gradient_backend", ["jax", "pytensor"], ids=str)
2227
def test_jax_functions_from_graph(gradient_backend: GradientBackend):
2328
x = pt.tensor("x", shape=(2,))

tests/test_laplace.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -143,9 +143,7 @@ def test_fit_laplace_coords(rng, transform_samples, mode):
143143
transform_samples=transform_samples,
144144
)
145145

146-
idata = sample_laplace(
147-
mu=mu, H_inv=H_inv, model=model, transform_samples=transform_samples
148-
)
146+
idata = sample_laplace(mu=mu, H_inv=H_inv, model=model, transform_samples=transform_samples)
149147

150148
np.testing.assert_allclose(np.mean(idata.posterior.mu, axis=1), np.full((2, 3), 3), atol=0.5)
151149
np.testing.assert_allclose(

0 commit comments

Comments
 (0)