Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion conda-envs/environment-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ dependencies:
- xhistogram
- statsmodels
- pip:
- pymc>=5.17.0 # CI was failing to resolve
- pymc>=5.19.1 # CI was failing to resolve
- blackjax
- scikit-learn
- better_optimize>=0.0.10
2 changes: 1 addition & 1 deletion conda-envs/windows-environment-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ dependencies:
- xhistogram
- statsmodels
- pip:
- pymc>=5.17.0 # CI was failing to resolve
- pymc>=5.19.1 # CI was failing to resolve
- blackjax
- scikit-learn
- better_optimize>=0.0.10
18 changes: 15 additions & 3 deletions pymc_experimental/distributions/timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,10 +281,20 @@ def discrete_mc_logp(op, values, P, steps, init_dist, state_rng, **kwargs):
class DiscreteMarkovChainGibbsMetropolis(CategoricalGibbsMetropolis):
name = "discrete_markov_chain_gibbs_metropolis"

def __init__(self, vars, proposal="uniform", order="random", model=None):
def __init__(
self,
vars,
proposal="uniform",
order="random",
model=None,
initial_point=None,
compile_kwargs: dict | None = None,
**kwargs,
):
model = pm.modelcontext(model)
vars = get_value_vars_from_user_vars(vars, model)
initial_point = model.initial_point()
if initial_point is None:
initial_point = model.initial_point()

dimcats = []
# The above variable is a list of pairs (aggregate dimension, number
Expand Down Expand Up @@ -332,7 +342,9 @@ def __init__(self, vars, proposal="uniform", order="random", model=None):
self.tune = True

# We bypass CategoryGibbsMetropolis's __init__ to avoid it's specialiazed initialization logic
ArrayStep.__init__(self, vars, [model.compile_logp()])
if compile_kwargs is None:
compile_kwargs = {}
ArrayStep.__init__(self, vars, [model.compile_logp(**compile_kwargs)], **kwargs)

@staticmethod
def competence(var):
Expand Down
4 changes: 2 additions & 2 deletions pymc_experimental/inference/laplace.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,14 +391,14 @@ def sample_laplace_posterior(

else:
info = mu.point_map_info
flat_shapes = [np.prod(shape).astype(int) for _, shape, _ in info]
flat_shapes = [size for _, _, size, _ in info]
slices = [
slice(sum(flat_shapes[:i]), sum(flat_shapes[: i + 1])) for i in range(len(flat_shapes))
]

posterior_draws = [
posterior_draws[..., idx].reshape((chains, draws, *shape)).astype(dtype)
for idx, (name, shape, dtype) in zip(slices, info)
for idx, (name, shape, _, dtype) in zip(slices, info)
]

idata = laplace_draws_to_inferencedata(posterior_draws, model)
Expand Down
7 changes: 5 additions & 2 deletions pymc_experimental/inference/smc/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,13 +206,16 @@ def arviz_from_particles(model, particles):
-------
"""
n_particles = jax.tree_util.tree_flatten(particles)[0][0].shape[0]
by_varname = {k.name: v.squeeze()[np.newaxis, :] for k, v in zip(model.value_vars, particles)}
by_varname = {
k.name: v.squeeze()[np.newaxis, :].astype(k.dtype)
for k, v in zip(model.value_vars, particles)
}
varnames = [v.name for v in model.value_vars]
with model:
strace = NDArray(name=model.name)
strace.setup(n_particles, 0)
for particle_index in range(0, n_particles):
strace.record(point={k: by_varname[k][0][particle_index] for k in varnames})
strace.record(point={k: np.asarray(by_varname[k][0][particle_index]) for k in varnames})
multitrace = MultiTrace((strace,))
return to_inference_data(multitrace, log_likelihood=False)

Expand Down
2 changes: 1 addition & 1 deletion pymc_experimental/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.1.3
0.1.4
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
pymc>=5.17.0
pymc>=5.19.1
scikit-learn
8 changes: 5 additions & 3 deletions tests/distributions/test_discrete_markov_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,16 +225,18 @@ def test_change_size_univariate(self):
def test_mcmc_sampling(self):
with pm.Model(coords={"step": range(100)}) as model:
init_dist = Categorical.dist(p=[0.5, 0.5])
DiscreteMarkovChain(
markov_chain = DiscreteMarkovChain(
"markov_chain",
P=[[0.1, 0.9], [0.1, 0.9]],
init_dist=init_dist,
shape=(100,),
dims="step",
)

step_method = assign_step_methods(model)
assert isinstance(step_method, DiscreteMarkovChainGibbsMetropolis)
_, assigned_step_methods = assign_step_methods(model)
assert assigned_step_methods[DiscreteMarkovChainGibbsMetropolis] == [
model.rvs_to_values[markov_chain]
]

# Sampler needs no tuning
idata = pm.sample(
Expand Down
Loading