Skip to content

Commit 075b47e

Browse files
Refactor where the intial point function is created
1 parent 9dc8a98 commit 075b47e

File tree

2 files changed

+109
-96
lines changed

2 files changed

+109
-96
lines changed

python/nutpie/compile_pymc.py

Lines changed: 103 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,15 @@
22
import itertools
33
import warnings
44
from dataclasses import dataclass
5-
from functools import partial
5+
from functools import wraps
66
from importlib.util import find_spec
77
from math import prod
8-
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional
8+
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union
99

1010
import numpy as np
1111
import pandas as pd
1212
from numpy.typing import NDArray
13+
from pymc.initial_point import make_initial_point_fn
1314

1415
from nutpie import _lib
1516
from nutpie.compiled_pyfunc import SeedType, from_pyfunc
@@ -26,6 +27,59 @@ def intrinsic(f):
2627
if TYPE_CHECKING:
2728
import numba.core.ccallback
2829
import pymc as pm
30+
from pytensor.tensor import TensorVariable, Variable
31+
32+
33+
def rv_dict_to_flat_array_wrapper(
34+
fn: Callable[[SeedType], dict[str, np.ndarray]],
35+
names: list[str],
36+
shapes: list[tuple[int]],
37+
) -> Callable[[SeedType], np.ndarray]:
38+
"""
39+
Wraps a function that returns a dictionary of string:array key:value pairs
40+
and returns a single flat float64 array. Also checks that the shapes of
41+
the arrays match the expected shapes.
42+
43+
Parameters
44+
----------
45+
fn: Callable
46+
Function that takes a seed and return a dictionary of variable names
47+
to initial values. This function should be the output of
48+
pymc.initial_point.make_initial_point_fn
49+
names: list of str
50+
List of random variable names in the model
51+
shapes: list of tuple of int
52+
Shape of random variables in the model
53+
54+
Returns
55+
-------
56+
seeded_array_fn: Callable
57+
Function that takes a seed and returns a flat, contiguous float64
58+
array of initial values. The ordering of the random variables inside
59+
the array is controlled by the ``names`` parameter.
60+
"""
61+
62+
@wraps(fn)
63+
def seeded_array_fn(seed: SeedType = None):
64+
inital_value_dict = fn(seed)
65+
total_size = sum(np.prod(shape) for shape in shapes)
66+
flat_array = np.empty(total_size, dtype="float64", order="C")
67+
cursor = 0
68+
69+
for name, shape in zip(names, shapes):
70+
initial_value = inital_value_dict[name]
71+
n = int(np.prod(initial_value.shape))
72+
if initial_value.shape != shape:
73+
raise ValueError(
74+
f"Size of initial value for {name} is {initial_value.shape}, "
75+
f"expected {shape}"
76+
)
77+
flat_array[cursor : cursor + n] = initial_value.ravel().astype("float64")
78+
cursor += n
79+
80+
return flat_array
81+
82+
return seeded_array_fn
2983

3084

3185
@intrinsic
@@ -159,7 +213,11 @@ def make_user_data(shared_vars, shared_data):
159213
return user_data
160214

161215

162-
def _compile_pymc_model_numba(model: "pm.Model", **kwargs) -> CompiledPyMCModel:
216+
def _compile_pymc_model_numba(
217+
model: "pm.Model",
218+
initial_point_fn: Callable[[SeedType], dict[str, np.ndarray]],
219+
**kwargs,
220+
) -> CompiledPyMCModel:
163221
if find_spec("numba") is None:
164222
raise ImportError(
165223
"Numba is not installed in the current environment. "
@@ -174,7 +232,6 @@ def _compile_pymc_model_numba(model: "pm.Model", **kwargs) -> CompiledPyMCModel:
174232
n_expanded,
175233
logp_fn_pt,
176234
expand_fn_pt,
177-
initial_fn_pt,
178235
shape_info,
179236
) = _make_functions(model, mode="NUMBA", compute_grad=True, join_expanded=True)
180237

@@ -223,15 +280,17 @@ def _compile_pymc_model_numba(model: "pm.Model", **kwargs) -> CompiledPyMCModel:
223280
expand_numba = numba.cfunc(c_sig_expand, **kwargs)(expand_numba_raw)
224281

225282
dims, coords = _prepare_dims_and_coords(model, shape_info)
226-
283+
initial_point_fn_array = rv_dict_to_flat_array_wrapper(
284+
initial_point_fn, names=shape_info[0], shapes=shape_info[-1]
285+
)
227286
return CompiledPyMCModel(
228287
_n_dim=n_dim,
229288
dims=dims,
230289
_coords=coords,
231290
_shapes={name: tuple(shape) for name, _, shape in zip(*shape_info)},
232291
compiled_logp_func=logp_numba,
233292
compiled_expand_func=expand_numba,
234-
initial_point_func=initial_fn_pt,
293+
initial_point_func=initial_point_fn_array,
235294
shared_data=shared_data,
236295
user_data=user_data,
237296
n_expanded=n_expanded,
@@ -266,7 +325,13 @@ def _prepare_dims_and_coords(model, shape_info):
266325
return dims, coords
267326

268327

269-
def _compile_pymc_model_jax(model, *, gradient_backend=None, **kwargs):
328+
def _compile_pymc_model_jax(
329+
model,
330+
*,
331+
gradient_backend=None,
332+
initial_point_fn: Callable[[SeedType], dict[str, np.ndarray]],
333+
**kwargs,
334+
):
270335
if find_spec("jax") is None:
271336
raise ImportError(
272337
"Jax is not installed in the current environment. "
@@ -286,7 +351,6 @@ def _compile_pymc_model_jax(model, *, gradient_backend=None, **kwargs):
286351
_,
287352
logp_fn_pt,
288353
expand_fn_pt,
289-
make_initial_point_py,
290354
shape_info,
291355
) = _make_functions(
292356
model,
@@ -343,11 +407,15 @@ def expand(x, **shared):
343407

344408
dims, coords = _prepare_dims_and_coords(model, shape_info)
345409

410+
initial_point_fn_array = rv_dict_to_flat_array_wrapper(
411+
initial_point_fn, names=shape_info[0], shapes=shape_info[-1]
412+
)
413+
346414
return from_pyfunc(
347415
ndim=n_dim,
348416
make_logp_fn=make_logp_func,
349417
make_expand_fn=make_expand_func,
350-
make_initial_point_fn=make_initial_point_py,
418+
make_initial_point_fn=initial_point_fn_array,
351419
expanded_dtypes=dtypes,
352420
expanded_shapes=shapes,
353421
expanded_names=names,
@@ -362,6 +430,9 @@ def compile_pymc_model(
362430
*,
363431
backend: Literal["numba", "jax"] = "numba",
364432
gradient_backend: Literal["pytensor", "jax"] = "pytensor",
433+
overrides: dict[Union["Variable", str], np.ndarray | float | int] | None = None,
434+
jitter_rvs: set["TensorVariable"] | None = None,
435+
default_strategy: Literal["support_point", "prior"] = "support_point",
365436
**kwargs,
366437
) -> CompiledModel:
367438
"""Compile necessary functions for sampling a pymc model.
@@ -375,7 +446,13 @@ def compile_pymc_model(
375446
gradient_backend: ["pytensor", "jax"]
376447
Which library is used to compute the gradients. This can only be
377448
changed to "jax" if the jax backend is used.
378-
449+
jitter_rvs : set
450+
The set (or list or tuple) of random variables for which a U(-1, +1) jitter should be
451+
added to the initial value. Only available for variables that have a transform or real-valued support.
452+
default_strategy : str
453+
Which of { "support_point", "prior" } to prefer if the initval setting for an RV is None.
454+
overrides : dict
455+
Initial value (strategies) to use instead of what's specified in `Model.initial_values`.
379456
Returns
380457
-------
381458
compiled_model : CompiledPyMCModel
@@ -390,13 +467,26 @@ def compile_pymc_model(
390467
"and restart your kernel in case you are in an interactive session."
391468
)
392469

470+
initial_point_fn = make_initial_point_fn(
471+
model=model,
472+
overrides=overrides,
473+
default_strategy=default_strategy,
474+
jitter_rvs=jitter_rvs,
475+
return_transformed=False,
476+
)
477+
393478
if backend.lower() == "numba":
394479
if gradient_backend == "jax":
395480
raise ValueError("Gradient backend cannot be jax when using numba backend")
396-
return _compile_pymc_model_numba(model, **kwargs)
481+
return _compile_pymc_model_numba(
482+
model, initial_point_fn=initial_point_fn, **kwargs
483+
)
397484
elif backend.lower() == "jax":
398485
return _compile_pymc_model_jax(
399-
model, gradient_backend=gradient_backend, **kwargs
486+
model,
487+
gradient_backend=gradient_backend,
488+
initial_point_fn=initial_point_fn,
489+
**kwargs,
400490
)
401491
else:
402492
raise ValueError(f"Backend must be one of numba and jax. Got {backend}")
@@ -434,12 +524,7 @@ def _compute_shapes(model):
434524
def _make_functions(
435525
model, *, mode, compute_grad, join_expanded
436526
) -> tuple[
437-
int,
438-
int,
439-
Callable,
440-
Callable,
441-
Callable,
442-
tuple[list[str], list[slice], list[tuple[int, ...]]]
527+
int, int, Callable, Callable, tuple[list[str], list[slice], list[tuple[int, ...]]]
443528
]:
444529
"""
445530
Compile functions required by nuts-rs from a given PyMC model.
@@ -468,18 +553,14 @@ def _make_functions(
468553
and the gradient, otherwise only the logp is returned.
469554
expand_fn_pt: Callable
470555
Compiled pytensor function that computes the remaining variables for the trace
471-
init_point_fn_pt: Callable
472-
...
473556
param_data: tuple of lists
474557
Tuple containing data necessary to unravel a flat array of model variables back into a ragged list of arrays.
475558
The first list contains the names of the variables, the second list contains the slices that correspond to the
476559
variables in the flat array, and the third list contains the shapes of the variables.
477560
"""
478561
import pytensor
479562
import pytensor.tensor as pt
480-
from pymc.initial_point import make_initial_point_fn
481563
from pymc.pytensorf import compile_pymc
482-
from pymc.initial_point import make_initial_point_fn
483564

484565
shapes = _compute_shapes(model)
485566

@@ -549,10 +630,6 @@ def _make_functions(
549630
with model:
550631
logp_fn_pt = compile_pymc((joined,), (logp,), mode=mode)
551632

552-
make_initial_point_py = partial(make_initial_point_fn,
553-
model=model,
554-
return_transformed=True)
555-
556633
# Make function that computes remaining variables for the trace
557634
remaining_rvs = [
558635
var for var in model.unobserved_value_vars if var.name not in joined_names
@@ -591,7 +668,6 @@ def _make_functions(
591668
num_expanded,
592669
logp_fn_pt,
593670
expand_fn_pt,
594-
make_initial_point_py,
595671
(all_names, all_slices, all_shapes),
596672
)
597673

python/nutpie/compiled_pyfunc.py

Lines changed: 6 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import dataclasses
22
from dataclasses import dataclass
3-
from functools import partial, wraps
4-
from typing import TYPE_CHECKING, Any, Callable, Literal, Union
3+
from functools import partial
4+
from typing import Any, Callable
55

66
import numpy as np
77

@@ -11,62 +11,6 @@
1111
SeedType = int | float | np.random.Generator | None
1212

1313

14-
if TYPE_CHECKING:
15-
from pytensor.tensor import TensorVariable, Variable
16-
17-
18-
def rv_dict_to_flat_array_wrapper(
19-
fn: Callable[[SeedType], dict[str, np.ndarray]],
20-
names: list[str],
21-
shapes: list[tuple[int]],
22-
) -> Callable[[SeedType], np.ndarray]:
23-
"""
24-
Wraps a function that returns a dictionary of string:array key:value pairs
25-
and returns a single flat float64 array. Also checks that the shapes of
26-
the arrays match the expected shapes.
27-
28-
Parameters
29-
----------
30-
fn: Callable
31-
Function that takes a seed and return a dictionary of variable names
32-
to initial values. This function should be the output of
33-
pymc.initial_point.make_initial_point_fn
34-
names: list of str
35-
List of random variable names in the model
36-
shapes: list of tuple of int
37-
Shape of random variables in the model
38-
39-
Returns
40-
-------
41-
seeded_array_fn: Callable
42-
Function that takes a seed and returns a flat, contiguous float64
43-
array of initial values. The ordering of the random variables inside
44-
the array is controlled by the ``names`` parameter.
45-
"""
46-
47-
@wraps(fn)
48-
def seeded_array_fn(seed: SeedType = None):
49-
inital_value_dict = fn(seed)
50-
total_size = sum(np.prod(shape) for shape in shapes)
51-
flat_array = np.empty(total_size, dtype="float64", order="C")
52-
cursor = 0
53-
54-
for name, shape in zip(names, shapes):
55-
initial_value = inital_value_dict[name]
56-
n = int(np.prod(initial_value.shape))
57-
if initial_value.shape != shape:
58-
raise ValueError(
59-
f"Size of initial value for {name} is {initial_value.shape}, "
60-
f"expected {shape}"
61-
)
62-
flat_array[cursor : cursor + n] = initial_value.ravel().astype("float64")
63-
cursor += n
64-
65-
return flat_array
66-
67-
return seeded_array_fn
68-
69-
7014
@dataclass(frozen=True)
7115
class PyFuncModel(CompiledModel):
7216
_make_logp_func: Callable
@@ -129,15 +73,11 @@ def from_pyfunc(
12973
ndim: int,
13074
make_logp_fn: Callable,
13175
make_expand_fn: Callable,
132-
make_initial_point_fn: Callable[[Any, Any, Any], Callable[[SeedType], np.ndarray]],
76+
make_initial_point_fn: Callable[[SeedType], dict[str, np.ndarray]],
13377
expanded_dtypes: list[np.dtype],
13478
expanded_shapes: list[tuple[int, ...]],
13579
expanded_names: list[str],
13680
*,
137-
initial_values: dict[Union["Variable", str], np.ndarray | float | int]
138-
| None = None,
139-
jitter_rvs: set["TensorVariable"] | None = None,
140-
default_initialization: Literal["support_point", "prior"] = "support_point",
14181
coords: dict[str, Any] | None = None,
14282
dims: dict[str, tuple[str, ...]] | None = None,
14383
shared_data: dict[str, Any] | None = None,
@@ -162,13 +102,10 @@ def from_pyfunc(
162102
if shared_data is None:
163103
shared_data = {}
164104

165-
initial_point_fn = make_initial_point_fn(
166-
overrides=initial_values,
167-
default_strategy=default_initialization,
168-
jitter_rvs=jitter_rvs,
169-
)
105+
from nutpie.compile_pymc import rv_dict_to_flat_array_wrapper
106+
170107
initial_point_fn = rv_dict_to_flat_array_wrapper(
171-
initial_point_fn, names=expanded_names, shapes=expanded_shapes
108+
make_initial_point_fn, names=expanded_names, shapes=expanded_shapes
172109
)
173110

174111
return PyFuncModel(

0 commit comments

Comments
 (0)