Skip to content

Commit 9f936b5

Browse files
Fix bug introduced by code refactor
1 parent 075b47e commit 9f936b5

File tree

2 files changed

+12
-16
lines changed

2 files changed

+12
-16
lines changed

python/nutpie/compile_pymc.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -61,19 +61,20 @@ def rv_dict_to_flat_array_wrapper(
6161

6262
@wraps(fn)
6363
def seeded_array_fn(seed: SeedType = None):
64-
inital_value_dict = fn(seed)
65-
total_size = sum(np.prod(shape) for shape in shapes)
64+
initial_value_dict = fn(seed)
65+
total_size = sum(np.prod(shape).astype(int) for shape in shapes)
6666
flat_array = np.empty(total_size, dtype="float64", order="C")
6767
cursor = 0
6868

6969
for name, shape in zip(names, shapes):
70-
initial_value = inital_value_dict[name]
70+
initial_value = initial_value_dict[name]
7171
n = int(np.prod(initial_value.shape))
7272
if initial_value.shape != shape:
7373
raise ValueError(
7474
f"Size of initial value for {name} is {initial_value.shape}, "
7575
f"expected {shape}"
7676
)
77+
7778
flat_array[cursor : cursor + n] = initial_value.ravel().astype("float64")
7879
cursor += n
7980

@@ -144,16 +145,16 @@ def with_data(self, **updates):
144145
user_data=user_data,
145146
)
146147

147-
def _make_sampler(self, settings, init_mean, cores, progress_type):
148-
model = self._make_model(init_mean)
148+
def _make_sampler(self, settings, cores, progress_type):
149+
model = self._make_model()
149150
return _lib.PySampler.from_pymc(
150151
settings,
151152
cores,
152153
model,
153154
progress_type,
154155
)
155156

156-
def _make_model(self, init_mean):
157+
def _make_model(self):
157158
expand_fn = _lib.ExpandFunc(
158159
self.n_dim,
159160
self.n_expanded,
@@ -169,14 +170,15 @@ def _make_model(self, init_mean):
169170
)
170171

171172
var_sizes = [prod(shape) for shape in self.shape_info[2]]
173+
var_names = self.shape_info[0]
172174

173175
return _lib.PyMcModel(
174176
self.n_dim,
175177
logp_fn,
176178
expand_fn,
177179
self.initial_point_func,
178180
var_sizes,
179-
self.shape_info[0],
181+
var_names,
180182
)
181183

182184

@@ -472,7 +474,7 @@ def compile_pymc_model(
472474
overrides=overrides,
473475
default_strategy=default_strategy,
474476
jitter_rvs=jitter_rvs,
475-
return_transformed=False,
477+
return_transformed=True,
476478
)
477479

478480
if backend.lower() == "numba":

python/nutpie/compiled_pyfunc.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def from_pyfunc(
7373
ndim: int,
7474
make_logp_fn: Callable,
7575
make_expand_fn: Callable,
76-
make_initial_point_fn: Callable[[SeedType], dict[str, np.ndarray]],
76+
make_initial_point_fn: Callable[[SeedType], np.ndarray],
7777
expanded_dtypes: list[np.dtype],
7878
expanded_shapes: list[tuple[int, ...]],
7979
expanded_names: list[str],
@@ -102,19 +102,13 @@ def from_pyfunc(
102102
if shared_data is None:
103103
shared_data = {}
104104

105-
from nutpie.compile_pymc import rv_dict_to_flat_array_wrapper
106-
107-
initial_point_fn = rv_dict_to_flat_array_wrapper(
108-
make_initial_point_fn, names=expanded_names, shapes=expanded_shapes
109-
)
110-
111105
return PyFuncModel(
112106
_n_dim=ndim,
113107
dims=dims,
114108
_coords=coords,
115109
_make_logp_func=make_logp_fn,
116110
_make_expand_func=make_expand_fn,
117-
_make_initial_points=initial_point_fn,
111+
_make_initial_points=make_initial_point_fn,
118112
_variables=variables,
119113
_shared_data=shared_data,
120114
)

0 commit comments

Comments
 (0)