Skip to content

Commit 2938d5a

Browse files
authored
feat: Handle missing libraries more robustly (#72)
* Handle missing libraries more robustly When 'pymc' or 'numba' or 'bridgestan' are missing imports, provide specific installation instructions. Don't rely on ImportError for missing packages, since that can be triggered downstream even when the respective package is installed. Have a single compile_X_model function for each X. Each function lazily imports X-related modules and handles missing imports. * Avoid `literal_unroll` with `numba.` prefix
1 parent e6c6c11 commit 2938d5a

File tree

3 files changed

+61
-34
lines changed

3 files changed

+61
-34
lines changed

python/nutpie/__init__.py

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,8 @@
11
from nutpie import _lib
22
from nutpie.sample import sample
33

4-
try:
5-
from .compile_pymc import compile_pymc_model
6-
except ImportError:
4+
from .compile_pymc import compile_pymc_model
5+
from .compile_stan import compile_stan_model
76

8-
def compile_pymc_model(*args, **kwargs):
9-
raise ValueError("Missing dependencies for pymc models. Install pymc.")
10-
11-
12-
try:
13-
from .compile_stan import compile_stan_model
14-
except ImportError:
15-
16-
def compile_stan_model(*args, **kwargs):
17-
raise ImportError("Missing dependencies for stan models. Install bridgestan.")
18-
19-
20-
__version__ = _lib.__version__
21-
22-
23-
__all__ = ["sample", "compile_pymc_model", "compile_stan_model"]
7+
__version__: str = _lib.__version__
8+
__all__ = ["__version__", "sample", "compile_pymc_model", "compile_stan_model"]

python/nutpie/compile_pymc.py

Lines changed: 49 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,31 @@
11
import dataclasses
22
import itertools
33
from dataclasses import dataclass
4+
from importlib.util import find_spec
45
from math import prod
5-
from typing import Any, Dict, Optional, Tuple
6+
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
67

7-
import numba
8-
import numba.core.ccallback
98
import numpy as np
109
import pandas as pd
11-
import pymc as pm
12-
import pytensor
13-
import pytensor.link.numba.dispatch
14-
import pytensor.tensor as pt
15-
from numba import literal_unroll
16-
from numba.cpython.unsafe.tuple import alloca_once, tuple_setitem
1710
from numpy.typing import NDArray
18-
from pymc.initial_point import make_initial_point_fn
1911

2012
from nutpie import _lib
2113
from nutpie.sample import CompiledModel
2214

15+
try:
16+
from numba.extending import intrinsic
17+
except ImportError:
2318

24-
@numba.extending.intrinsic
19+
def intrinsic(f):
20+
return f
21+
22+
23+
if TYPE_CHECKING:
24+
import numba.core.ccallback
25+
import pymc as pm
26+
27+
28+
@intrinsic
2529
def address_as_void_pointer(typingctx, src):
2630
"""returns a void pointer from a given memory address"""
2731
from numba.core import cgutils, types
@@ -36,8 +40,8 @@ def codegen(cgctx, builder, sig, args):
3640

3741
@dataclass(frozen=True)
3842
class CompiledPyMCModel(CompiledModel):
39-
compiled_logp_func: numba.core.ccallback.CFunc
40-
compiled_expand_func: numba.core.ccallback.CFunc
43+
compiled_logp_func: "numba.core.ccallback.CFunc"
44+
compiled_expand_func: "numba.core.ccallback.CFunc"
4145
shared_data: Dict[str, NDArray]
4246
user_data: NDArray
4347
n_expanded: int
@@ -144,7 +148,7 @@ def make_user_data(func, shared_data):
144148
return user_data
145149

146150

147-
def compile_pymc_model(model: pm.Model, **kwargs) -> CompiledPyMCModel:
151+
def compile_pymc_model(model: "pm.Model", **kwargs) -> CompiledPyMCModel:
148152
"""Compile necessary functions for sampling a pymc model.
149153
150154
Parameters
@@ -158,6 +162,21 @@ def compile_pymc_model(model: pm.Model, **kwargs) -> CompiledPyMCModel:
158162
A compiled model object.
159163
160164
"""
165+
if find_spec("pymc") is None:
166+
raise ImportError(
167+
"PyMC is not installed in the current environment. "
168+
"Please install it with something like "
169+
"'mamba install -c conda-forge pymc numba' "
170+
"and restart your kernel in case you are in an interactive session."
171+
)
172+
if find_spec("numba") is None:
173+
raise ImportError(
174+
"Numba is not installed in the current environment. "
175+
"Please install it with something like "
176+
"'mamba install -c conda-forge numba' "
177+
"and restart your kernel in case you are in an interactive session."
178+
)
179+
import numba
161180

162181
(
163182
n_dim,
@@ -220,6 +239,9 @@ def compile_pymc_model(model: pm.Model, **kwargs) -> CompiledPyMCModel:
220239

221240

222241
def _compute_shapes(model):
242+
import pytensor
243+
from pymc.initial_point import make_initial_point_fn
244+
223245
point = make_initial_point_fn(model=model, return_transformed=True)(0)
224246

225247
trace_vars = {
@@ -246,6 +268,10 @@ def _compute_shapes(model):
246268

247269

248270
def _make_functions(model):
271+
import pytensor
272+
import pytensor.link.numba.dispatch
273+
import pytensor.tensor as pt
274+
249275
shapes = _compute_shapes(model)
250276

251277
# Make logp_dlogp_function
@@ -358,6 +384,10 @@ def _make_functions(model):
358384

359385

360386
def make_extraction_fn(inner, shared_data, shared_vars, record_dtype):
387+
import numba
388+
from numba import literal_unroll
389+
from numba.cpython.unsafe.tuple import alloca_once, tuple_setitem
390+
361391
if not shared_vars:
362392

363393
@numba.njit(inline="always")
@@ -380,7 +410,7 @@ def extract_shared(x, user_data_):
380410
indices = tuple(range(len(names)))
381411
shared_tuple = tuple(shared_data[name] for name in shared_vars)
382412

383-
@numba.extending.intrinsic
413+
@intrinsic
384414
def tuple_setitem_literal(typingctx, tup, idx, val):
385415
"""Return a copy of the tuple with item at *idx* replaced with *val*."""
386416
if not isinstance(idx, numba.types.IntegerLiteral):
@@ -451,6 +481,8 @@ def extract_shared(x, user_data_):
451481

452482

453483
def _make_c_logp_func(n_dim, logp_fn, user_data, shared_logp, shared_data):
484+
import numba
485+
454486
extract = make_extraction_fn(logp_fn, shared_data, shared_logp, user_data.dtype)
455487

456488
c_sig = numba.types.int64(
@@ -490,6 +522,8 @@ def logp_numba(dim, x_, out_, logp_, user_data_):
490522
def _make_c_expand_func(
491523
n_dim, n_expanded, expand_fn, user_data, shared_vars, shared_data
492524
):
525+
import numba
526+
493527
extract = make_extraction_fn(expand_fn, shared_data, shared_vars, user_data.dtype)
494528

495529
c_sig = numba.types.int64(

python/nutpie/compile_stan.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import pathlib
33
import tempfile
44
from dataclasses import dataclass, replace
5+
from importlib.util import find_spec
56
from typing import Any, Dict, List, Optional
67

78
import numpy as np
@@ -113,6 +114,13 @@ def compile_stan_model(
113114
model_name: Optional[str] = None,
114115
cleanup: bool = True,
115116
) -> CompiledStanModel:
117+
if find_spec("bridgestan") is None:
118+
raise ImportError(
119+
"BridgeStan is not installed in the current environment. "
120+
"Please install it with something like "
121+
"'pip install bridgestan'."
122+
)
123+
116124
import bridgestan
117125

118126
if dims is None:

0 commit comments

Comments
 (0)