1
1
import dataclasses
2
2
import itertools
3
3
from dataclasses import dataclass
4
+ from importlib .util import find_spec
4
5
from math import prod
5
- from typing import Any , Dict , Optional , Tuple
6
+ from typing import TYPE_CHECKING , Any , Dict , Optional , Tuple
6
7
7
- import numba
8
- import numba .core .ccallback
9
8
import numpy as np
10
9
import pandas as pd
11
- import pymc as pm
12
- import pytensor
13
- import pytensor .link .numba .dispatch
14
- import pytensor .tensor as pt
15
- from numba import literal_unroll
16
- from numba .cpython .unsafe .tuple import alloca_once , tuple_setitem
17
10
from numpy .typing import NDArray
18
- from pymc .initial_point import make_initial_point_fn
19
11
20
12
from nutpie import _lib
21
13
from nutpie .sample import CompiledModel
22
14
15
# Import numba's ``intrinsic`` decorator if numba is available; otherwise
# install an identity stand-in so this module can still be *imported*
# without numba (the real intrinsic machinery is only needed when a model
# is actually compiled).
try:
    from numba.extending import intrinsic
except ImportError:

    def intrinsic(func):
        # No-op replacement: hand the decorated function back unchanged.
        return func
23
+ if TYPE_CHECKING :
24
+ import numba .core .ccallback
25
+ import pymc as pm
26
+
27
+
28
+ @intrinsic
25
29
def address_as_void_pointer (typingctx , src ):
26
30
"""returns a void pointer from a given memory address"""
27
31
from numba .core import cgutils , types
@@ -36,8 +40,8 @@ def codegen(cgctx, builder, sig, args):
36
40
37
41
@dataclass (frozen = True )
38
42
class CompiledPyMCModel (CompiledModel ):
39
- compiled_logp_func : numba .core .ccallback .CFunc
40
- compiled_expand_func : numba .core .ccallback .CFunc
43
+ compiled_logp_func : " numba.core.ccallback.CFunc"
44
+ compiled_expand_func : " numba.core.ccallback.CFunc"
41
45
shared_data : Dict [str , NDArray ]
42
46
user_data : NDArray
43
47
n_expanded : int
@@ -144,7 +148,7 @@ def make_user_data(func, shared_data):
144
148
return user_data
145
149
146
150
147
- def compile_pymc_model (model : pm .Model , ** kwargs ) -> CompiledPyMCModel :
151
+ def compile_pymc_model (model : " pm.Model" , ** kwargs ) -> CompiledPyMCModel :
148
152
"""Compile necessary functions for sampling a pymc model.
149
153
150
154
Parameters
@@ -158,6 +162,21 @@ def compile_pymc_model(model: pm.Model, **kwargs) -> CompiledPyMCModel:
158
162
A compiled model object.
159
163
160
164
"""
165
+ if find_spec ("pymc" ) is None :
166
+ raise ImportError (
167
+ "PyMC is not installed in the current environment. "
168
+ "Please install it with something like "
169
+ "'mamba install -c conda-forge pymc numba' "
170
+ "and restart your kernel in case you are in an interactive session."
171
+ )
172
+ if find_spec ("numba" ) is None :
173
+ raise ImportError (
174
+ "Numba is not installed in the current environment. "
175
+ "Please install it with something like "
176
+ "'mamba install -c conda-forge numba' "
177
+ "and restart your kernel in case you are in an interactive session."
178
+ )
179
+ import numba
161
180
162
181
(
163
182
n_dim ,
@@ -220,6 +239,9 @@ def compile_pymc_model(model: pm.Model, **kwargs) -> CompiledPyMCModel:
220
239
221
240
222
241
def _compute_shapes (model ):
242
+ import pytensor
243
+ from pymc .initial_point import make_initial_point_fn
244
+
223
245
point = make_initial_point_fn (model = model , return_transformed = True )(0 )
224
246
225
247
trace_vars = {
@@ -246,6 +268,10 @@ def _compute_shapes(model):
246
268
247
269
248
270
def _make_functions (model ):
271
+ import pytensor
272
+ import pytensor .link .numba .dispatch
273
+ import pytensor .tensor as pt
274
+
249
275
shapes = _compute_shapes (model )
250
276
251
277
# Make logp_dlogp_function
@@ -358,6 +384,10 @@ def _make_functions(model):
358
384
359
385
360
386
def make_extraction_fn (inner , shared_data , shared_vars , record_dtype ):
387
+ import numba
388
+ from numba import literal_unroll
389
+ from numba .cpython .unsafe .tuple import alloca_once , tuple_setitem
390
+
361
391
if not shared_vars :
362
392
363
393
@numba .njit (inline = "always" )
@@ -380,7 +410,7 @@ def extract_shared(x, user_data_):
380
410
indices = tuple (range (len (names )))
381
411
shared_tuple = tuple (shared_data [name ] for name in shared_vars )
382
412
383
- @numba . extending . intrinsic
413
+ @intrinsic
384
414
def tuple_setitem_literal (typingctx , tup , idx , val ):
385
415
"""Return a copy of the tuple with item at *idx* replaced with *val*."""
386
416
if not isinstance (idx , numba .types .IntegerLiteral ):
@@ -451,6 +481,8 @@ def extract_shared(x, user_data_):
451
481
452
482
453
483
def _make_c_logp_func (n_dim , logp_fn , user_data , shared_logp , shared_data ):
484
+ import numba
485
+
454
486
extract = make_extraction_fn (logp_fn , shared_data , shared_logp , user_data .dtype )
455
487
456
488
c_sig = numba .types .int64 (
@@ -490,6 +522,8 @@ def logp_numba(dim, x_, out_, logp_, user_data_):
490
522
def _make_c_expand_func (
491
523
n_dim , n_expanded , expand_fn , user_data , shared_vars , shared_data
492
524
):
525
+ import numba
526
+
493
527
extract = make_extraction_fn (expand_fn , shared_data , shared_vars , user_data .dtype )
494
528
495
529
c_sig = numba .types .int64 (
0 commit comments