Skip to content

Commit 37be3fd

Browse files
committed
Updated import statements and linting fixes
1 parent 428e3eb commit 37be3fd

File tree

14 files changed

+74
-55
lines changed

14 files changed

+74
-55
lines changed

pymc_extras/__init__.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,12 @@
1313
# limitations under the License.
1414
import logging
1515

16-
from pymc_experimental import gp, statespace, utils
17-
from pymc_experimental.distributions import *
18-
from pymc_experimental.inference.fit import fit
19-
from pymc_experimental.model.marginal.marginal_model import MarginalModel, marginalize
20-
from pymc_experimental.model.model_api import as_model
21-
from pymc_experimental.version import __version__
16+
from pymc_extras import gp, statespace, utils
17+
from pymc_extras.distributions import *
18+
from pymc_extras.inference.fit import fit
19+
from pymc_extras.model.marginal.marginal_model import MarginalModel, marginalize
20+
from pymc_extras.model.model_api import as_model
21+
from pymc_extras.version import __version__
2222

2323
_log = logging.getLogger("pmx")
2424

@@ -27,3 +27,5 @@
2727
if len(_log.handlers) == 0:
2828
handler = logging.StreamHandler()
2929
_log.addHandler(handler)
30+
31+
__all__ = ["fit", "MarginalModel", "marginalize", "as_model"]

pymc_extras/distributions/__init__.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,24 +17,24 @@
1717
Experimental probability distributions for stochastic nodes in PyMC.
1818
"""
1919

20-
from pymc_experimental.distributions.continuous import Chi, GenExtreme, Maxwell
21-
from pymc_experimental.distributions.discrete import (
20+
from pymc_extras.distributions.continuous import Chi, GenExtreme, Maxwell
21+
from pymc_extras.distributions.discrete import (
2222
BetaNegativeBinomial,
2323
GeneralizedPoisson,
2424
Skellam,
2525
)
26-
from pymc_experimental.distributions.histogram_utils import histogram_approximation
27-
from pymc_experimental.distributions.multivariate import R2D2M2CP
28-
from pymc_experimental.distributions.timeseries import DiscreteMarkovChain
26+
from pymc_extras.distributions.histogram_utils import histogram_approximation
27+
from pymc_extras.distributions.multivariate import R2D2M2CP
28+
from pymc_extras.distributions.timeseries import DiscreteMarkovChain
2929

3030
__all__ = [
31-
"BetaNegativeBinomial",
31+
"Chi",
32+
"Maxwell",
3233
"DiscreteMarkovChain",
3334
"GeneralizedPoisson",
35+
"BetaNegativeBinomial",
3436
"GenExtreme",
3537
"R2D2M2CP",
3638
"Skellam",
3739
"histogram_approximation",
38-
"Chi",
39-
"Maxwell",
4040
]

pymc_extras/gp/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,6 @@
1313
# limitations under the License.
1414

1515

16-
from pymc_experimental.gp.latent_approx import KarhunenLoeveExpansion, ProjectedProcess
16+
from pymc_extras.gp.latent_approx import KarhunenLoeveExpansion, ProjectedProcess
1717

1818
__all__ = ["KarhunenLoeveExpansion", "ProjectedProcess"]

pymc_extras/inference/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,6 @@
1313
# limitations under the License.
1414

1515

16-
from pymc_experimental.inference.fit import fit
16+
from pymc_extras.inference.fit import fit
1717

1818
__all__ = ["fit"]

pymc_extras/model/marginal/marginal_model.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,18 +23,18 @@
2323

2424
__all__ = ["MarginalModel", "marginalize"]
2525

26-
from pymc_experimental.distributions import DiscreteMarkovChain
27-
from pymc_experimental.model.marginal.distributions import (
28-
MarginalDiscreteMarkovChainRV,
29-
MarginalFiniteDiscreteRV,
30-
get_domain_of_finite_discrete_rv,
31-
reduce_batch_dependent_logps,
26+
from pymc_extras.distributions import DiscreteMarkovChain
27+
from pymc_extras.model.marginal.distributions import (
28+
FiniteDiscreteMarginalRV,
29+
MarginalRV,
30+
MarginalizedRV,
3231
)
33-
from pymc_experimental.model.marginal.graph_analysis import (
34-
find_conditional_dependent_rvs,
35-
find_conditional_input_rvs,
36-
is_conditional_dependent,
37-
subgraph_batch_dim_connection,
32+
from pymc_extras.model.marginal.graph_analysis import (
33+
get_support_axes,
34+
get_support_shape,
35+
get_support_size,
36+
get_support_values,
37+
get_variable_support,
3838
)
3939

4040
ModelRVs = TensorVariable | Sequence[TensorVariable] | str | Sequence[str]

pymc_extras/statespace/__init__.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,13 @@
1-
from pymc_experimental.statespace.core.compile import compile_statespace
2-
from pymc_experimental.statespace.models import structural
3-
from pymc_experimental.statespace.models.ETS import BayesianETS
4-
from pymc_experimental.statespace.models.SARIMAX import BayesianSARIMA
5-
from pymc_experimental.statespace.models.VARMAX import BayesianVARMAX
1+
from pymc_extras.statespace.core.compile import compile_statespace
2+
from pymc_extras.statespace.models import structural
3+
from pymc_extras.statespace.models.ETS import BayesianETS
4+
from pymc_extras.statespace.models.SARIMAX import BayesianSARIMA
5+
from pymc_extras.statespace.models.VARMAX import BayesianVARMAX
66

7-
__all__ = ["structural", "BayesianSARIMA", "BayesianVARMAX", "BayesianETS", "compile_statespace"]
7+
__all__ = [
8+
"compile_statespace",
9+
"structural",
10+
"BayesianETS",
11+
"BayesianSARIMA",
12+
"BayesianVARMAX",
13+
]

pymc_extras/utils/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
# limitations under the License.
1414

1515

16-
from pymc_experimental.utils import prior, spline
17-
from pymc_experimental.utils.linear_cg import linear_cg
16+
from pymc_extras.utils import prior, spline
17+
from pymc_extras.utils.linear_cg import linear_cg
1818

1919
__all__ = (
2020
"linear_cg",

tests/test_blackjax_smc.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,23 @@
2323
jax = pytest.importorskip("jax")
2424
pytest.importorskip("blackjax")
2525

26-
from pymc_experimental.inference.smc.sampling import (
27-
arviz_from_particles,
28-
blackjax_particles_from_pymc_population,
29-
get_jaxified_loglikelihood,
30-
get_jaxified_logprior,
31-
sample_smc_blackjax,
26+
from pymc_extras.inference.smc.sampling import (
27+
SMC_KERNEL_DEFAULTS,
28+
SMC_KERNEL_NAMES,
29+
SMC_KERNEL_PARAMS,
30+
SMC_KERNEL_PARAMS_DEFAULTS,
31+
SMC_KERNEL_PARAMS_NAMES,
32+
SMC_KERNEL_PARAMS_TYPES,
33+
SMC_KERNEL_TYPES,
34+
SMC_KERNEL_VALID_PARAMS,
35+
SMC_KERNEL_VALID_PARAMS_TYPES,
36+
SMC_KERNELS,
37+
SMC_KERNELS_PARAMS,
38+
SMC_KERNELS_PARAMS_DEFAULTS,
39+
SMC_KERNELS_PARAMS_NAMES,
40+
SMC_KERNELS_PARAMS_TYPES,
41+
SMC_KERNELS_VALID_PARAMS,
42+
SMC_KERNELS_VALID_PARAMS_TYPES,
3243
)
3344

3445

tests/test_find_map.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
import pytensor.tensor as pt
44
import pytest
55

6-
from pymc_experimental.inference.find_map import (
7-
GradientBackend,
6+
from pymc_extras.inference.find_map import (
87
find_MAP,
9-
scipy_optimize_funcs_from_loss,
8+
find_MAP_scipy,
9+
find_MAP_pytensor,
1010
)
1111

1212
pytest.importorskip("jax")

tests/test_histogram_approximation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import pymc as pm
1818
import pytest
1919

20-
import pymc_experimental as pmx
20+
import pymc_extras as pmx
2121

2222

2323
@pytest.mark.parametrize("use_dask", [True, False], ids="dask={}".format)

0 commit comments

Comments
 (0)