 import itertools
 import warnings
 from dataclasses import dataclass
-from functools import partial
+from functools import wraps
 from importlib.util import find_spec
 from math import prod
-from typing import TYPE_CHECKING, Any, Callable, Literal, Optional
+from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union
 
 import numpy as np
 import pandas as pd
 from numpy.typing import NDArray
+from pymc.initial_point import make_initial_point_fn
 
 from nutpie import _lib
 from nutpie.compiled_pyfunc import SeedType, from_pyfunc
@@ -26,6 +27,59 @@ def intrinsic(f):
 if TYPE_CHECKING:
     import numba.core.ccallback
     import pymc as pm
+    from pytensor.tensor import TensorVariable, Variable
+
+
+def rv_dict_to_flat_array_wrapper(
+    fn: Callable[[SeedType], dict[str, np.ndarray]],
+    names: list[str],
+    shapes: list[tuple[int]],
+) -> Callable[[SeedType], np.ndarray]:
+    """
+    Wrap a function that returns a dictionary of name/array pairs so that it
+    instead returns a single flat float64 array. Also check that the shape of
+    each array matches the expected shape.
+
+    Parameters
+    ----------
+    fn: Callable
+        Function that takes a seed and returns a dictionary mapping variable
+        names to initial values. This function should be the output of
+        pymc.initial_point.make_initial_point_fn.
+    names: list of str
+        List of random variable names in the model.
+    shapes: list of tuple of int
+        Shapes of the random variables in the model.
+
+    Returns
+    -------
+    seeded_array_fn: Callable
+        Function that takes a seed and returns a flat, contiguous float64
+        array of initial values. The ordering of the random variables inside
+        the array is controlled by the ``names`` parameter.
+    """
+    @wraps(fn)
+    def seeded_array_fn(seed: SeedType = None):
+        initial_value_dict = fn(seed)
+        total_size = sum(np.prod(shape) for shape in shapes)
+        flat_array = np.empty(total_size, dtype="float64", order="C")
+        cursor = 0
+
+        for name, shape in zip(names, shapes):
+            initial_value = initial_value_dict[name]
+            n = int(np.prod(initial_value.shape))
+            if initial_value.shape != shape:
+                raise ValueError(
+                    f"Shape of initial value for {name} is {initial_value.shape}, "
+                    f"expected {shape}"
+                )
+            flat_array[cursor : cursor + n] = initial_value.ravel().astype("float64")
+            cursor += n
+
+        return flat_array
+
+    return seeded_array_fn
 
 
 @intrinsic
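
A minimal usage sketch for the wrapper above (not part of the diff), using a hypothetical stand-in for the dict-returning function that pymc.initial_point.make_initial_point_fn would normally provide:

    import numpy as np

    def fake_initial_point(seed=None):
        # Hypothetical stand-in for the output of make_initial_point_fn
        return {"mu": np.zeros(2), "beta": np.ones((2, 3))}

    flat_fn = rv_dict_to_flat_array_wrapper(
        fake_initial_point, names=["mu", "beta"], shapes=[(2,), (2, 3)]
    )
    flat_fn(123)  # contiguous float64 array of length 8: mu first, then beta in C order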
@@ -159,7 +213,11 @@ def make_user_data(shared_vars, shared_data):
     return user_data
 
 
-def _compile_pymc_model_numba(model: "pm.Model", **kwargs) -> CompiledPyMCModel:
+def _compile_pymc_model_numba(
+    model: "pm.Model",
+    initial_point_fn: Callable[[SeedType], dict[str, np.ndarray]],
+    **kwargs,
+) -> CompiledPyMCModel:
     if find_spec("numba") is None:
         raise ImportError(
             "Numba is not installed in the current environment. "
@@ -174,7 +232,6 @@ def _compile_pymc_model_numba(model: "pm.Model", **kwargs) -> CompiledPyMCModel:
         n_expanded,
         logp_fn_pt,
         expand_fn_pt,
-        initial_fn_pt,
         shape_info,
     ) = _make_functions(model, mode="NUMBA", compute_grad=True, join_expanded=True)
 
@@ -223,15 +280,17 @@ def _compile_pymc_model_numba(model: "pm.Model", **kwargs) -> CompiledPyMCModel:
     expand_numba = numba.cfunc(c_sig_expand, **kwargs)(expand_numba_raw)
 
     dims, coords = _prepare_dims_and_coords(model, shape_info)
-
+    initial_point_fn_array = rv_dict_to_flat_array_wrapper(
+        initial_point_fn, names=shape_info[0], shapes=shape_info[-1]
+    )
     return CompiledPyMCModel(
         _n_dim=n_dim,
         dims=dims,
         _coords=coords,
         _shapes={name: tuple(shape) for name, _, shape in zip(*shape_info)},
         compiled_logp_func=logp_numba,
         compiled_expand_func=expand_numba,
-        initial_point_func=initial_fn_pt,
+        initial_point_func=initial_point_fn_array,
         shared_data=shared_data,
         user_data=user_data,
         n_expanded=n_expanded,
@@ -266,7 +325,13 @@ def _prepare_dims_and_coords(model, shape_info):
     return dims, coords
 
 
-def _compile_pymc_model_jax(model, *, gradient_backend=None, **kwargs):
+def _compile_pymc_model_jax(
+    model,
+    *,
+    gradient_backend=None,
+    initial_point_fn: Callable[[SeedType], dict[str, np.ndarray]],
+    **kwargs,
+):
     if find_spec("jax") is None:
         raise ImportError(
             "Jax is not installed in the current environment. "
@@ -286,7 +351,6 @@ def _compile_pymc_model_jax(model, *, gradient_backend=None, **kwargs):
         _,
         logp_fn_pt,
         expand_fn_pt,
-        make_initial_point_py,
         shape_info,
     ) = _make_functions(
         model,
@@ -343,11 +407,15 @@ def expand(x, **shared):
 
     dims, coords = _prepare_dims_and_coords(model, shape_info)
 
+    initial_point_fn_array = rv_dict_to_flat_array_wrapper(
+        initial_point_fn, names=shape_info[0], shapes=shape_info[-1]
+    )
+
     return from_pyfunc(
         ndim=n_dim,
         make_logp_fn=make_logp_func,
         make_expand_fn=make_expand_func,
-        make_initial_point_fn=make_initial_point_py,
+        make_initial_point_fn=initial_point_fn_array,
         expanded_dtypes=dtypes,
         expanded_shapes=shapes,
         expanded_names=names,
@@ -362,6 +430,9 @@ def compile_pymc_model(
     *,
     backend: Literal["numba", "jax"] = "numba",
     gradient_backend: Literal["pytensor", "jax"] = "pytensor",
+    overrides: dict[Union["Variable", str], np.ndarray | float | int] | None = None,
+    jitter_rvs: set["TensorVariable"] | None = None,
+    default_strategy: Literal["support_point", "prior"] = "support_point",
     **kwargs,
 ) -> CompiledModel:
     """Compile necessary functions for sampling a pymc model.
@@ -375,7 +446,13 @@ def compile_pymc_model(
     gradient_backend: ["pytensor", "jax"]
         Which library is used to compute the gradients. This can only be
         changed to "jax" if the jax backend is used.
-
+    jitter_rvs: set
+        The set (or list or tuple) of random variables for which a U(-1, +1) jitter should be
+        added to the initial value. Only available for variables that have a transform or real-valued support.
+    default_strategy: str
+        Which of {"support_point", "prior"} to prefer if the initval setting for an RV is None.
+    overrides: dict
+        Initial value (strategies) to use instead of what's specified in ``Model.initial_values``.
     Returns
     -------
     compiled_model : CompiledPyMCModel
@@ -390,13 +467,26 @@ def compile_pymc_model(
390467 "and restart your kernel in case you are in an interactive session."
391468 )
392469
470+ initial_point_fn = make_initial_point_fn (
471+ model = model ,
472+ overrides = overrides ,
473+ default_strategy = default_strategy ,
474+ jitter_rvs = jitter_rvs ,
475+ return_transformed = False ,
476+ )
477+
393478 if backend .lower () == "numba" :
394479 if gradient_backend == "jax" :
395480 raise ValueError ("Gradient backend cannot be jax when using numba backend" )
396- return _compile_pymc_model_numba (model , ** kwargs )
481+ return _compile_pymc_model_numba (
482+ model , initial_point_fn = initial_point_fn , ** kwargs
483+ )
397484 elif backend .lower () == "jax" :
398485 return _compile_pymc_model_jax (
399- model , gradient_backend = gradient_backend , ** kwargs
486+ model ,
487+ gradient_backend = gradient_backend ,
488+ initial_point_fn = initial_point_fn ,
489+ ** kwargs ,
400490 )
401491 else :
402492 raise ValueError (f"Backend must be one of numba and jax. Got { backend } " )
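
A sketch of how the new keyword arguments might be used from user code; the model, variable names, and data here are hypothetical, and nutpie.sample is the usual entry point for sampling a compiled model:

    import numpy as np
    import pymc as pm
    import nutpie

    with pm.Model() as model:
        mu = pm.Normal("mu", 0.0, 1.0)
        sigma = pm.HalfNormal("sigma", 1.0)
        pm.Normal("y", mu, sigma, observed=np.random.randn(100))

    compiled = nutpie.compile_pymc_model(
        model,
        jitter_rvs={mu},                   # add U(-1, 1) jitter to mu's starting point
        overrides={"sigma": 2.0},          # fixed initial value for sigma
        default_strategy="support_point",  # fall back to the support point elsewhere
    )
    trace = nutpie.sample(compiled)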
@@ -434,12 +524,7 @@ def _compute_shapes(model):
 def _make_functions(
     model, *, mode, compute_grad, join_expanded
 ) -> tuple[
-    int,
-    int,
-    Callable,
-    Callable,
-    Callable,
-    tuple[list[str], list[slice], list[tuple[int, ...]]]
+    int, int, Callable, Callable, tuple[list[str], list[slice], list[tuple[int, ...]]]
 ]:
     """
     Compile functions required by nuts-rs from a given PyMC model.
@@ -468,18 +553,14 @@ def _make_functions(
         and the gradient, otherwise only the logp is returned.
     expand_fn_pt: Callable
         Compiled pytensor function that computes the remaining variables for the trace
-    init_point_fn_pt: Callable
-        ...
     param_data: tuple of lists
         Tuple containing data necessary to unravel a flat array of model variables back into a ragged list of arrays.
         The first list contains the names of the variables, the second list contains the slices that correspond to the
         variables in the flat array, and the third list contains the shapes of the variables.
     """
     import pytensor
     import pytensor.tensor as pt
-    from pymc.initial_point import make_initial_point_fn
     from pymc.pytensorf import compile_pymc
-    from pymc.initial_point import make_initial_point_fn
 
     shapes = _compute_shapes(model)
 
@@ -549,10 +630,6 @@ def _make_functions(
     with model:
         logp_fn_pt = compile_pymc((joined,), (logp,), mode=mode)
 
-    make_initial_point_py = partial(make_initial_point_fn,
-                                    model=model,
-                                    return_transformed=True)
-
     # Make function that computes remaining variables for the trace
     remaining_rvs = [
         var for var in model.unobserved_value_vars if var.name not in joined_names
@@ -591,7 +668,6 @@ def _make_functions(
         num_expanded,
         logp_fn_pt,
         expand_fn_pt,
-        make_initial_point_py,
         (all_names, all_slices, all_shapes),
     )
 
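
For orientation, a minimal sketch (not part of the change) of how the (names, slices, shapes) triple returned by _make_functions can be used to unravel a flat array back into named, correctly shaped arrays, as the docstring above describes; unravel is a hypothetical helper, not an existing function:

    import numpy as np

    def unravel(flat_draw, names, slices, shapes):
        # Split a flat 1-D array into one array per variable, restoring shapes.
        return {
            name: np.asarray(flat_draw[slc]).reshape(shape)
            for name, slc, shape in zip(names, slices, shapes)
        }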