diff --git a/conda-envs/environment-test.yml b/conda-envs/environment-test.yml
index 2c84fb6e3..c457329e2 100644
--- a/conda-envs/environment-test.yml
+++ b/conda-envs/environment-test.yml
@@ -10,7 +10,7 @@ dependencies:
 - xhistogram
 - statsmodels
 - pip:
-  - pymc>=5.17.0 # CI was failing to resolve
+  - pymc>=5.19.1 # CI was failing to resolve
   - blackjax
   - scikit-learn
   - better_optimize>=0.0.10
diff --git a/conda-envs/windows-environment-test.yml b/conda-envs/windows-environment-test.yml
index 2c84fb6e3..c457329e2 100644
--- a/conda-envs/windows-environment-test.yml
+++ b/conda-envs/windows-environment-test.yml
@@ -10,7 +10,7 @@ dependencies:
 - xhistogram
 - statsmodels
 - pip:
-  - pymc>=5.17.0 # CI was failing to resolve
+  - pymc>=5.19.1 # CI was failing to resolve
   - blackjax
   - scikit-learn
   - better_optimize>=0.0.10
diff --git a/pymc_experimental/distributions/timeseries.py b/pymc_experimental/distributions/timeseries.py
index d4cd94356..034ecb5b1 100644
--- a/pymc_experimental/distributions/timeseries.py
+++ b/pymc_experimental/distributions/timeseries.py
@@ -281,10 +281,20 @@ def discrete_mc_logp(op, values, P, steps, init_dist, state_rng, **kwargs):
 class DiscreteMarkovChainGibbsMetropolis(CategoricalGibbsMetropolis):
     name = "discrete_markov_chain_gibbs_metropolis"

-    def __init__(self, vars, proposal="uniform", order="random", model=None):
+    def __init__(
+        self,
+        vars,
+        proposal="uniform",
+        order="random",
+        model=None,
+        initial_point=None,
+        compile_kwargs: dict | None = None,
+        **kwargs,
+    ):
         model = pm.modelcontext(model)
         vars = get_value_vars_from_user_vars(vars, model)
-        initial_point = model.initial_point()
+        if initial_point is None:
+            initial_point = model.initial_point()

         dimcats = []
         # The above variable is a list of pairs (aggregate dimension, number
@@ -332,7 +342,9 @@ def __init__(self, vars, proposal="uniform", order="random", model=None):
         self.tune = True

         # We bypass CategoryGibbsMetropolis's __init__ to avoid it's specialiazed initialization logic
-        ArrayStep.__init__(self, vars, [model.compile_logp()])
+        if compile_kwargs is None:
+            compile_kwargs = {}
+        ArrayStep.__init__(self, vars, [model.compile_logp(**compile_kwargs)], **kwargs)

     @staticmethod
     def competence(var):
diff --git a/pymc_experimental/inference/laplace.py b/pymc_experimental/inference/laplace.py
index 24a72c0f7..26b18bb08 100644
--- a/pymc_experimental/inference/laplace.py
+++ b/pymc_experimental/inference/laplace.py
@@ -391,14 +391,14 @@ def sample_laplace_posterior(
     else:
         info = mu.point_map_info

-    flat_shapes = [np.prod(shape).astype(int) for _, shape, _ in info]
+    flat_shapes = [size for _, _, size, _ in info]
     slices = [
         slice(sum(flat_shapes[:i]), sum(flat_shapes[: i + 1])) for i in range(len(flat_shapes))
     ]

     posterior_draws = [
         posterior_draws[..., idx].reshape((chains, draws, *shape)).astype(dtype)
-        for idx, (name, shape, dtype) in zip(slices, info)
+        for idx, (name, shape, _, dtype) in zip(slices, info)
     ]

     idata = laplace_draws_to_inferencedata(posterior_draws, model)
diff --git a/pymc_experimental/inference/smc/sampling.py b/pymc_experimental/inference/smc/sampling.py
index 898db598b..7173bc417 100644
--- a/pymc_experimental/inference/smc/sampling.py
+++ b/pymc_experimental/inference/smc/sampling.py
@@ -206,13 +206,16 @@ def arviz_from_particles(model, particles):
     -------
     """
     n_particles = jax.tree_util.tree_flatten(particles)[0][0].shape[0]
-    by_varname = {k.name: v.squeeze()[np.newaxis, :] for k, v in zip(model.value_vars, particles)}
+    by_varname = {
+        k.name: v.squeeze()[np.newaxis, :].astype(k.dtype)
+        for k, v in zip(model.value_vars, particles)
+    }
     varnames = [v.name for v in model.value_vars]
     with model:
         strace = NDArray(name=model.name)
         strace.setup(n_particles, 0)
         for particle_index in range(0, n_particles):
-            strace.record(point={k: by_varname[k][0][particle_index] for k in varnames})
+            strace.record(point={k: np.asarray(by_varname[k][0][particle_index]) for k in varnames})
     multitrace = MultiTrace((strace,))
     return to_inference_data(multitrace, log_likelihood=False)

diff --git a/pymc_experimental/version.txt b/pymc_experimental/version.txt
index b1e80bb24..845639eef 100644
--- a/pymc_experimental/version.txt
+++ b/pymc_experimental/version.txt
@@ -1 +1 @@
-0.1.3
+0.1.4
diff --git a/requirements.txt b/requirements.txt
index b992ad37b..1051d2eeb 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,2 +1,2 @@
-pymc>=5.17.0
+pymc>=5.19.1
 scikit-learn
diff --git a/tests/distributions/test_discrete_markov_chain.py b/tests/distributions/test_discrete_markov_chain.py
index 0d855ef44..28e9ed154 100644
--- a/tests/distributions/test_discrete_markov_chain.py
+++ b/tests/distributions/test_discrete_markov_chain.py
@@ -225,7 +225,7 @@ def test_change_size_univariate(self):
     def test_mcmc_sampling(self):
         with pm.Model(coords={"step": range(100)}) as model:
             init_dist = Categorical.dist(p=[0.5, 0.5])
-            DiscreteMarkovChain(
+            markov_chain = DiscreteMarkovChain(
                 "markov_chain",
                 P=[[0.1, 0.9], [0.1, 0.9]],
                 init_dist=init_dist,
@@ -233,8 +233,10 @@ def test_mcmc_sampling(self):
                 dims="step",
             )

-            step_method = assign_step_methods(model)
-            assert isinstance(step_method, DiscreteMarkovChainGibbsMetropolis)
+            _, assigned_step_methods = assign_step_methods(model)
+            assert assigned_step_methods[DiscreteMarkovChainGibbsMetropolis] == [
+                model.rvs_to_values[markov_chain]
+            ]

             # Sampler needs no tuning
             idata = pm.sample(
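
As a quick sanity check on the new step-method signature, here is a minimal usage sketch (not part of the patch; the model mirrors test_mcmc_sampling above, and the FAST_COMPILE mode is an arbitrary illustrative choice):

import pymc as pm

from pymc_experimental.distributions.timeseries import (
    DiscreteMarkovChain,
    DiscreteMarkovChainGibbsMetropolis,
)

with pm.Model(coords={"step": range(100)}) as model:
    init_dist = pm.Categorical.dist(p=[0.5, 0.5])
    markov_chain = DiscreteMarkovChain(
        "markov_chain", P=[[0.1, 0.9], [0.1, 0.9]], init_dist=init_dist, dims="step"
    )
    # The sampler can now reuse a precomputed initial point and forward
    # compile_kwargs to model.compile_logp (previously the initial point was
    # always recomputed and logp compilation took no options).
    step = DiscreteMarkovChainGibbsMetropolis(
        vars=[markov_chain],
        initial_point=model.initial_point(),
        compile_kwargs={"mode": "FAST_COMPILE"},  # illustrative choice only
    )
    idata = pm.sample(step=step, chains=1, tune=0, draws=200)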