Skip to content

Commit d072b94

Browse files
committed
Update SymbolicRandomVariables to manage RNGs explicitly
* Also fixes error when jaxifying MarginalModel logp
1 parent a05001b commit d072b94

File tree

4 files changed

+40
-21
lines changed

4 files changed

+40
-21
lines changed

pymc_experimental/distributions/timeseries.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -202,13 +202,13 @@ def transition(*args):
202202
discrete_mc_ = pt.moveaxis(pt.concatenate([init_dist_, markov_chain], axis=0), 0, -1)
203203

204204
discrete_mc_op = DiscreteMarkovChainRV(
205-
inputs=[P_, steps_, init_dist_],
205+
inputs=[P_, steps_, init_dist_, state_rng],
206206
outputs=[state_next_rng, discrete_mc_],
207207
ndim_supp=1,
208208
n_lags=n_lags,
209209
)
210210

211-
discrete_mc = discrete_mc_op(P, steps, init_dist)
211+
discrete_mc = discrete_mc_op(P, steps, init_dist, state_rng)
212212
return discrete_mc
213213

214214

pymc_experimental/model/marginal_model.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -410,7 +410,7 @@ def transform_input(inputs):
410410
marginalized_rv.type, dependent_logps
411411
)
412412

413-
rv_shape = constant_fold(tuple(marginalized_rv.shape))
413+
rv_shape = constant_fold(tuple(marginalized_rv.shape), raise_not_constant=False)
414414
rv_domain = get_domain_of_finite_discrete_rv(marginalized_rv)
415415
rv_domain_tensor = pt.moveaxis(
416416
pt.full(
@@ -579,6 +579,15 @@ def is_elemwise_subgraph(rv_to_marginalize, other_input_rvs, output_rvs):
579579
return True
580580

581581

582+
from pytensor.graph.basic import graph_inputs
583+
584+
585+
def collect_shared_vars(outputs, blockers):
586+
return [
587+
inp for inp in graph_inputs(outputs, blockers=blockers) if isinstance(inp, SharedVariable)
588+
]
589+
590+
582591
def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs):
583592
# TODO: This should eventually be integrated in a more general routine that can
584593
# identify other types of supported marginalization, of which finite discrete
@@ -621,27 +630,21 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
621630
rvs_to_marginalize = [rv_to_marginalize, *dependent_rvs]
622631

623632
outputs = rvs_to_marginalize
624-
# Clone replace inner RV rng inputs so that we can be sure of the update order
625-
# replace_inputs = {rng: rng.type() for rng in updates_rvs_to_marginalize.keys()}
626-
# Clone replace outter RV inputs, so that their shared RNGs don't make it into
627-
# the inner graph of the marginalized RVs
628-
# FIXME: This shouldn't be needed!
629-
replace_inputs = {}
630-
replace_inputs.update({input_rv: input_rv.type() for input_rv in input_rvs})
631-
cloned_outputs = clone_replace(outputs, replace=replace_inputs)
633+
# We are strict about shared variables in SymbolicRandomVariables
634+
inputs = input_rvs + collect_shared_vars(rvs_to_marginalize, blockers=input_rvs)
632635

633636
if isinstance(rv_to_marginalize.owner.op, DiscreteMarkovChain):
634637
marginalize_constructor = DiscreteMarginalMarkovChainRV
635638
else:
636639
marginalize_constructor = FiniteDiscreteMarginalRV
637640

638641
marginalization_op = marginalize_constructor(
639-
inputs=list(replace_inputs.values()),
640-
outputs=cloned_outputs,
642+
inputs=inputs,
643+
outputs=outputs,
641644
ndim_supp=ndim_supp,
642645
)
643646

644-
marginalized_rvs = marginalization_op(*replace_inputs.keys())
647+
marginalized_rvs = marginalization_op(*inputs)
645648
fgraph.replace_all(tuple(zip(rvs_to_marginalize, marginalized_rvs)))
646649
return rvs_to_marginalize, marginalized_rvs
647650

pymc_experimental/statespace/filters/distributions.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -193,12 +193,12 @@ def step_fn(*args):
193193

194194
(ss_rng,) = tuple(updates.values())
195195
linear_gaussian_ss_op = LinearGaussianStateSpaceRV(
196-
inputs=[a0_, P0_, c_, d_, T_, Z_, R_, H_, Q_, steps_],
196+
inputs=[a0_, P0_, c_, d_, T_, Z_, R_, H_, Q_, steps_, rng],
197197
outputs=[ss_rng, statespace_],
198198
ndim_supp=1,
199199
)
200200

201-
linear_gaussian_ss = linear_gaussian_ss_op(a0, P0, c, d, T, Z, R, H, Q, steps)
201+
linear_gaussian_ss = linear_gaussian_ss_op(a0, P0, c, d, T, Z, R, H, Q, steps, rng)
202202
return linear_gaussian_ss
203203

204204

@@ -354,10 +354,10 @@ def step(mu, cov, rng):
354354
(seq_mvn_rng,) = tuple(updates.values())
355355

356356
mvn_seq_op = KalmanFilterRV(
357-
inputs=[mus_, covs_, logp_, steps_], outputs=[seq_mvn_rng, mvn_seq], ndim_supp=2
357+
inputs=[mus_, covs_, logp_, steps_, rng], outputs=[seq_mvn_rng, mvn_seq], ndim_supp=2
358358
)
359359

360-
mvn_seq = mvn_seq_op(mus, covs, logp, steps)
360+
mvn_seq = mvn_seq_op(mus, covs, logp, steps, rng)
361361
return mvn_seq
362362

363363

pymc_experimental/tests/model/test_marginal_model.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,9 @@ def test_marginalized_bernoulli_logp():
6060
[idx, y],
6161
ndim_supp=0,
6262
n_updates=0,
63-
)(
64-
mu
65-
)[0].owner
63+
# Ignore the fact we didn't specify shared RNG input/outputs for idx,y
64+
strict=False,
65+
)(mu)[0].owner
6666

6767
y_vv = y.clone()
6868
(logp,) = _logprob(
@@ -758,3 +758,19 @@ def test_marginalized_hmm_multiple_emissions(batch_emission1, batch_emission2):
758758
test_value_emission2 = np.broadcast_to(-test_value, emission2_shape)
759759
test_point = {"emission_1": test_value_emission1, "emission_2": test_value_emission2}
760760
np.testing.assert_allclose(logp_fn(test_point), expected_logp)
761+
762+
763+
def test_mutable_indexing_jax_backend():
764+
pytest.importorskip("jax")
765+
from pymc.sampling.jax import get_jaxified_logp
766+
767+
with MarginalModel() as model:
768+
data = pm.Data(f"data", np.zeros(10))
769+
770+
cat_effect = pm.Normal("cat_effect", sigma=1, shape=5)
771+
cat_effect_idx = pm.Data("cat_effect_idx", np.array([0, 1] * 5))
772+
773+
is_outlier = pm.Bernoulli("is_outlier", 0.4, shape=10)
774+
pm.LogNormal("y", mu=cat_effect[cat_effect_idx], sigma=1 + is_outlier, observed=data)
775+
model.marginalize(["is_outlier"])
776+
get_jaxified_logp(model)

0 commit comments

Comments
 (0)