Commit e6bf0e3

Additional fixes requested by ruff
1 parent 9aee194 commit e6bf0e3

16 files changed, +49 -33 lines changed


pymc_experimental/__init__.py

Lines changed: 15 additions & 4 deletions
@@ -13,6 +13,10 @@
 # limitations under the License.
 import logging

+from pymc_experimental import distributions, gp, statespace, utils
+from pymc_experimental.inference.fit import fit
+from pymc_experimental.model.marginal_model import MarginalModel
+from pymc_experimental.model.model_api import as_model
 from pymc_experimental.version import __version__

 _log = logging.getLogger("pmx")
@@ -23,7 +27,14 @@
 handler = logging.StreamHandler()
 _log.addHandler(handler)

-from pymc_experimental import distributions, gp, statespace, utils
-from pymc_experimental.inference.fit import fit
-from pymc_experimental.model.marginal_model import MarginalModel
-from pymc_experimental.model.model_api import as_model
+
+__all__ = [
+    "distributions",
+    "gp",
+    "statespace",
+    "utils",
+    "fit",
+    "MarginalModel",
+    "as_model",
+    "__version__",
+]
Lines changed: 2 additions & 0 deletions
@@ -1 +1,3 @@
 from pymc_experimental.distributions.multivariate.r2d2m2cp import R2D2M2CP
+
+__all__ = ["R2D2M2CP"]

pymc_experimental/gp/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -14,3 +14,5 @@


 from pymc_experimental.gp.latent_approx import KarhunenLoeveExpansion, ProjectedProcess
+
+__all__ = ["KarhunenLoeveExpansion", "ProjectedProcess"]

pymc_experimental/inference/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -14,3 +14,5 @@


 from pymc_experimental.inference.fit import fit
+
+__all__ = ["fit"]

pymc_experimental/inference/fit.py

Lines changed: 3 additions & 4 deletions
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from importlib.util import find_spec


 def fit(method, **kwargs):
@@ -30,10 +31,8 @@ def fit(method, **kwargs):
     arviz.InferenceData
     """
     if method == "pathfinder":
-        try:
-            import blackjax
-        except ImportError as exc:
-            raise RuntimeError("Need BlackJAX to use `pathfinder`") from exc
+        if find_spec("blackjax") is None:
+            raise RuntimeError("Need BlackJAX to use `pathfinder`")

         from pymc_experimental.inference.pathfinder import fit_pathfinder
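
Aside: a minimal sketch of the `importlib.util.find_spec` pattern adopted above. It checks whether an optional dependency is importable without actually importing it. The helper name below is illustrative, not part of the commit.

from importlib.util import find_spec

def require(module_name: str) -> None:
    # find_spec returns None when the module cannot be located on sys.path
    if find_spec(module_name) is None:
        raise RuntimeError(f"Optional dependency `{module_name}` is not installed")

require("blackjax")  # raises RuntimeError if BlackJAX is missing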

pymc_experimental/model/marginal_model.py

Lines changed: 1 addition & 3 deletions
@@ -21,6 +21,7 @@
 from pytensor import Mode, scan
 from pytensor.compile import SharedVariable
 from pytensor.graph import Constant, FunctionGraph, ancestors, clone_replace
+from pytensor.graph.basic import graph_inputs
 from pytensor.graph.replace import graph_replace, vectorize_graph
 from pytensor.scan import map as scan_map
 from pytensor.tensor import TensorType, TensorVariable
@@ -638,9 +639,6 @@ def is_elemwise_subgraph(rv_to_marginalize, other_input_rvs, output_rvs):
     return True


-from pytensor.graph.basic import graph_inputs
-
-
 def collect_shared_vars(outputs, blockers):
     return [
         inp for inp in graph_inputs(outputs, blockers=blockers) if isinstance(inp, SharedVariable)
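
Aside: a small sketch (my own, not from the commit) of what `graph_inputs` does, since the fix only moves its import to the top of the module. It walks a PyTensor graph and yields the variables with no owner, i.e. the roots the outputs depend on; `collect_shared_vars` then filters those down to `SharedVariable`s.

import pytensor.tensor as pt
from pytensor.graph.basic import graph_inputs

x = pt.scalar("x")
y = pt.scalar("y")
z = x + 2 * y

# Roots of the expression graph: x, y, and the constant 2
print(list(graph_inputs([z])))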

pymc_experimental/model/transforms/autoreparam.py

Lines changed: 4 additions & 1 deletion
@@ -419,6 +419,9 @@ def vip_reparametrize(
         lambda_names.append(lam.name)
     toposort_replace(fmodel, replacements, reverse=True)
     reparam_model = model_from_fgraph(fmodel)
-    model_lambdas = {n: reparam_model[l] for l, n in zip(lambda_names, var_names)}
+    model_lambdas = {
+        var_name: reparam_model[lambda_name]
+        for lambda_name, var_name in zip(lambda_names, var_names)
+    }
     vip = VIP(model_lambdas)
     return reparam_model, vip

pymc_experimental/statespace/core/representation.py

Lines changed: 1 addition & 3 deletions
@@ -1,7 +1,5 @@
 import copy

-from typing import Union
-
 import numpy as np
 import pytensor
 import pytensor.tensor as pt
@@ -12,7 +10,7 @@
 )

 floatX = pytensor.config.floatX
-KeyLike = Union[tuple[str | int, ...], str]
+KeyLike = tuple[str | int, ...] | str


 class PytensorRepresentation:
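
Aside: the `KeyLike` change only swaps `typing.Union` for the PEP 604 `|` spelling; both describe the same keys. A hedged illustration (the function and example keys are made up, not from this module):

KeyLike = tuple[str | int, ...] | str

def describe(key: KeyLike) -> str:
    # Accepts either a bare name or a name plus integer indices
    return key if isinstance(key, str) else "/".join(map(str, key))

print(describe("transition"))          # "transition"
print(describe(("transition", 0, 1)))  # "transition/0/1"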

pymc_experimental/statespace/filters/distributions.py

Lines changed: 0 additions & 4 deletions
@@ -96,8 +96,6 @@ def update(self, node: Node):


 class _LinearGaussianStateSpace(Continuous):
-    rv_op = LinearGaussianStateSpaceRV
-
     def __new__(
         cls,
         name,
@@ -360,8 +358,6 @@ def update(self, node: Node):


 class SequenceMvNormal(Continuous):
-    rv_op = KalmanFilterRV
-
     def __new__(cls, *args, **kwargs):
         return super().__new__(cls, *args, **kwargs)

pymc_experimental/statespace/filters/kalman_filter.py

Lines changed: 10 additions & 7 deletions
@@ -351,11 +351,14 @@ def handle_missing_values(
         self, y, Z, H
     ) -> tuple[TensorVariable, TensorVariable, TensorVariable, float]:
         """
-        This function handles missing values in the observation data `y` and adjusts the design matrix `Z` and the
-        observation noise covariance matrix `H` accordingly. Missing values are replaced with zeros to prevent
-        propagating NaNs through the computation. The function also returns a binary flag tensor `all_nan_flag`,
-        indicating if all values in the observation data are missing. This flag is used for numerical adjustments in
-        the update method.
+        Handle missing values in the observation data `y`.
+
+        Adjusts the design matrix `Z` and the observation noise covariance matrix `H` by removing rows and/or
+        columns associated with the data that is not observed at this iteration. Missing values are replaced
+        with zeros to prevent propagating NaNs through the computation.
+
+        Returns a binary flag tensor `all_nan_flag`, indicating whether all values in the observation data are
+        missing. This flag is used for numerical adjustments in the update method.

         Parameters
         ----------
@@ -660,7 +663,7 @@ def update(self, a, P, y, c, d, Z, H, all_nan_flag):


 class CholeskyFilter(BaseFilter):
-    """ "
+    """
     Kalman filter with Cholesky factorization

     Kalman filter implementation using a Cholesky factorization plus pt.solve_triangular to (attempt) to speed up
@@ -712,7 +715,7 @@ class SingleTimeseriesFilter(BaseFilter):

     # TODO: This class should eventually be made irrelevant by pytensor re-writes.
     def check_params(self, data, a0, P0, c, d, T, Z, R, H, Q):
-        """ "
+        """
         Wrap the data in an `Assert` `Op` to ensure there is only one observed state.
         """
         data = assert_data_is_1d(data, pt.eq(data.shape[1], 1))
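
Aside: an illustrative NumPy sketch (my own, not code from this commit) of the masking behaviour the rewritten docstring describes: zero out missing observations, knock out the corresponding rows and columns of `Z` and `H`, and report whether the whole observation vector is missing.

import numpy as np

def mask_missing(y, Z, H):
    nan_mask = np.isnan(y)
    W = np.diag((~nan_mask).astype(float))  # selection matrix: zeros out rows/cols tied to missing entries
    y_masked = np.where(nan_mask, 0.0, y)   # replace NaNs with zeros so they do not propagate
    Z_masked = W @ Z
    H_masked = W @ H @ W
    all_nan_flag = float(nan_mask.all())
    return y_masked, Z_masked, H_masked, all_nan_flag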

0 commit comments
