|
33 | 33 | )
|
34 | 34 | from pymc.backends.ndarray import NDArray
|
35 | 35 | from pymc.blocking import DictToArrayBijection
|
| 36 | +from pymc.initial_point import make_initial_point_expression |
36 | 37 | from pymc.model import Point, modelcontext
|
37 |
| -from pymc.sampling.forward import sample_prior_predictive |
| 38 | +from pymc.sampling.forward import draw |
38 | 39 | from pymc.step_methods.metropolis import MultivariateNormalProposal
|
39 | 40 | from pymc.vartypes import discrete_types
|
40 | 41 |
|
@@ -182,13 +183,20 @@ def initialize_population(self) -> Dict[str, np.ndarray]:
|
182 | 183 | "ignore", category=UserWarning, message="The effect of Potentials"
|
183 | 184 | )
|
184 | 185 |
|
185 |
| - result = sample_prior_predictive( |
186 |
| - self.draws, |
187 |
| - var_names=[v.name for v in self.model.unobserved_value_vars], |
188 |
| - model=self.model, |
189 |
| - return_inferencedata=False, |
| 186 | + model = self.model |
| 187 | + prior_expression = make_initial_point_expression( |
| 188 | + free_rvs=model.free_RVs, |
| 189 | + rvs_to_transforms=model.rvs_to_transforms, |
| 190 | + initval_strategies={}, |
| 191 | + default_strategy="prior", |
| 192 | + return_transformed=True, |
190 | 193 | )
|
191 |
| - return cast(Dict[str, np.ndarray], result) |
| 194 | + prior_values = draw(prior_expression, draws=self.draws, random_seed=self.rng) |
| 195 | + |
| 196 | + names = [model.rvs_to_values[rv].name for rv in model.free_RVs] |
| 197 | + dict_prior = {k: np.stack(v) for k, v in zip(names, prior_values)} |
| 198 | + |
| 199 | + return cast(Dict[str, np.ndarray], dict_prior) |
192 | 200 |
|
193 | 201 | def _initialize_kernel(self):
|
194 | 202 | """Create variables and logp function necessary to run kernel
|
@@ -325,12 +333,11 @@ def _posterior_to_trace(self, chain=0) -> NDArray:
|
325 | 333 | for i in range(lenght_pos):
|
326 | 334 | value = []
|
327 | 335 | size = 0
|
328 |
| - for varname in varnames: |
329 |
| - shape, new_size = self.var_info[varname] |
| 336 | + for var in self.variables: |
| 337 | + shape, new_size = self.var_info[var.name] |
330 | 338 | var_samples = self.tempered_posterior[i][size : size + new_size]
|
331 | 339 | # Round discrete variable samples. The rounded values were the ones
|
332 | 340 | # actually used in the logp evaluations (see logp_forw)
|
333 |
| - var = self.model[varname] |
334 | 341 | if var.dtype in discrete_types:
|
335 | 342 | var_samples = np.round(var_samples).astype(var.dtype)
|
336 | 343 | value.append(var_samples.reshape(shape))
|
|
0 commit comments