Skip to content

Commit e93c7bb

Browse files
committed
style: Fix pre-commit issues
1 parent e42e972 commit e93c7bb

File tree

6 files changed

+60
-29
lines changed

6 files changed

+60
-29
lines changed

python/nutpie/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
1+
from nutpie import _lib
12
from nutpie.compile_pymc import compile_pymc_model
23
from nutpie.compile_stan import compile_stan_model
34
from nutpie.sample import sample
45

5-
from nutpie import _lib
6-
76
__version__: str = _lib.__version__
87
__all__ = ["__version__", "sample", "compile_pymc_model", "compile_stan_model"]

python/nutpie/compile_pymc.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,15 @@
44
from dataclasses import dataclass
55
from importlib.util import find_spec
66
from math import prod
7-
from typing import TYPE_CHECKING, Any, Optional
7+
from typing import TYPE_CHECKING, Any, Literal, Optional
88

99
import numpy as np
1010
import pandas as pd
1111
from numpy.typing import NDArray
12-
from nutpie.compiled_pyfunc import from_pyfunc
13-
from nutpie.sample import CompiledModel
1412

1513
from nutpie import _lib
14+
from nutpie.compiled_pyfunc import from_pyfunc
15+
from nutpie.sample import CompiledModel
1616

1717
try:
1818
from numba.extending import intrinsic
@@ -184,7 +184,7 @@ def _compile_pymc_model_numba(model: "pm.Model", **kwargs) -> CompiledPyMCModel:
184184
for val in [*logp_fn_pt.get_shared(), *expand_fn_pt.get_shared()]:
185185
if val.name in shared_data and val not in seen:
186186
raise ValueError(f"Shared variables must have unique names: {val.name}")
187-
shared_data[val.name] = val.get_value().copy()
187+
shared_data[val.name] = val.get_value()
188188
shared_vars[val.name] = val
189189
seen.add(val)
190190

@@ -308,7 +308,7 @@ def logp_fn_jax_grad(x, **shared):
308308
for val in [*logp_fn_pt.get_shared(), *expand_fn_pt.get_shared()]:
309309
if val.name in shared_data and val not in seen:
310310
raise ValueError(f"Shared variables must have unique names: {val.name}")
311-
shared_data[val.name] = jax.numpy.asarray(val.get_value().copy())
311+
shared_data[val.name] = jax.numpy.asarray(val.get_value())
312312
shared_vars[val.name] = val
313313
seen.add(val)
314314

@@ -356,8 +356,12 @@ def expand(x, **shared):
356356

357357

358358
def compile_pymc_model(
359-
model: "pm.Model", *, backend="numba", gradient_backend=None, **kwargs
360-
) -> CompiledPyMCModel:
359+
model: "pm.Model",
360+
*,
361+
backend: Literal["numba", "jax"] = "numba",
362+
gradient_backend: Literal["pytensor", "jax"] | None = None,
363+
**kwargs,
364+
) -> CompiledModel:
361365
"""Compile necessary functions for sampling a pymc model.
362366
363367
Parameters

python/nutpie/compile_stan.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88
import numpy as np
99
import pandas as pd
1010
from numpy.typing import NDArray
11-
from nutpie.sample import CompiledModel
1211

1312
from nutpie import _lib
13+
from nutpie.sample import CompiledModel
1414

1515

1616
class _NumpyArrayEncoder(json.JSONEncoder):

python/nutpie/compiled_pyfunc.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
import dataclasses
22
from dataclasses import dataclass
33
from functools import partial
4-
from typing import Any, Callable, List
4+
from typing import Any, Callable
55

66
import numpy as np
7-
from nutpie.sample import CompiledModel
87

98
from nutpie import _lib
9+
from nutpie.sample import CompiledModel
1010

1111

1212
@dataclass(frozen=True)
@@ -15,7 +15,7 @@ class PyFuncModel(CompiledModel):
1515
_make_expand_func: Callable
1616
_shared_data: dict[str, Any]
1717
_n_dim: int
18-
_variables: List[_lib.PyVariable]
18+
_variables: list[_lib.PyVariable]
1919
_coords: dict[str, Any]
2020

2121
@property
@@ -66,17 +66,17 @@ def make_expand_func(seed1, seed2, chain):
6666

6767

6868
def from_pyfunc(
69-
ndim,
70-
make_logp_fn,
71-
make_expand_fn,
72-
expanded_dtypes,
73-
expanded_shapes,
74-
expanded_names,
69+
ndim: int,
70+
make_logp_fn: Callable,
71+
make_expand_fn: Callable,
72+
expanded_dtypes: list[np.dtype],
73+
expanded_shapes: list[tuple[int, ...]],
74+
expanded_names: list[str],
7575
*,
76-
initial_mean=None,
77-
coords=None,
78-
dims=None,
79-
shared_data=None,
76+
initial_mean: np.ndarray | None = None,
77+
coords: dict[str, Any] | None = None,
78+
dims: dict[str, tuple[str, ...]] | None = None,
79+
shared_data: dict[str, Any] | None = None,
8080
):
8181
variables = []
8282
for name, shape, dtype in zip(
@@ -91,6 +91,13 @@ def from_pyfunc(
9191
dtype = _lib.ExpandDtype.int64_array(shape)
9292
variables.append(_lib.PyVariable(name, dtype))
9393

94+
if coords is None:
95+
coords = {}
96+
if dims is None:
97+
dims = {}
98+
if shared_data is None:
99+
shared_data = {}
100+
94101
if shared_data is None:
95102
shared_data = dict()
96103
return PyFuncModel(

python/nutpie/sample.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -284,8 +284,11 @@ def in_colab():
284284
shell = get_ipython().__class__.__name__
285285
if shell == "ZMQInteractiveShell": # Jupyter notebook, Spyder or qtconsole
286286
try:
287-
from IPython.display import (HTML, clear_output, # noqa: F401
288-
display)
287+
from IPython.display import (
288+
HTML, # noqa: F401
289+
clear_output, # noqa: F401
290+
display, # noqa: F401
291+
)
289292

290293
return True
291294
except ImportError:

tests/test_pymc.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import nutpie
66
import nutpie.compile_pymc
77

8-
98
parameterize_backends = pytest.mark.parametrize(
109
"backend, gradient_backend",
1110
[("numba", None), ("jax", "pytensor"), ("jax", "jax")],
@@ -131,8 +130,8 @@ def test_det(backend, gradient_backend):
131130
@parameterize_backends
132131
def test_pymc_model_shared(backend, gradient_backend):
133132
with pm.Model() as model:
134-
mu = pm.MutableData("mu", 0.1)
135-
sigma = pm.MutableData("sigma", np.ones(3))
133+
mu = pm.Data("mu", 0.1)
134+
sigma = pm.Data("sigma", np.ones(3))
136135
pm.Normal("a", mu=mu, sigma=sigma, shape=3)
137136

138137
compiled = nutpie.compile_pymc_model(
@@ -150,7 +149,26 @@ def test_pymc_model_shared(backend, gradient_backend):
150149
nutpie.sample(compiled3, chains=1)
151150

152151

153-
@parameterize_backends
152+
@pytest.mark.parametrize(
153+
("backend", "gradient_backend"),
154+
[
155+
("numba", None),
156+
pytest.param(
157+
"jax",
158+
"pytensor",
159+
marks=pytest.mark.xfail(
160+
reason="https://github.com/pymc-devs/pytensor/issues/853"
161+
),
162+
),
163+
pytest.param(
164+
"jax",
165+
"jax",
166+
marks=pytest.mark.xfail(
167+
reason="https://github.com/pymc-devs/pytensor/issues/853"
168+
),
169+
),
170+
],
171+
)
154172
def test_missing(backend, gradient_backend):
155173
with pm.Model(coords={"obs": range(4)}) as model:
156174
mu = pm.Normal("mu")

0 commit comments

Comments
 (0)