diff --git a/pymc/pytensorf.py b/pymc/pytensorf.py index 82eca936b..78eb3f7bb 100644 --- a/pymc/pytensorf.py +++ b/pymc/pytensorf.py @@ -979,7 +979,7 @@ def constant_fold( """ fg = FunctionGraph(outputs=xs, features=[ShapeFeature()], copy_inputs=False, clone=True) - # The default rewrite_graph includes a constand_folding that is not always applied. + # The default rewrite_graph includes a constant_folding that is not always applied. # We use an unconditional constant_folding as the last pass to ensure a thorough constant folding. rewrite_graph(fg) topo_unconditional_constant_folding.apply(fg) diff --git a/pymc/variational/minibatch_rv.py b/pymc/variational/minibatch_rv.py index 163cec472..34d30cfa5 100644 --- a/pymc/variational/minibatch_rv.py +++ b/pymc/variational/minibatch_rv.py @@ -40,6 +40,9 @@ def make_node(self, rv, *total_size): out = rv.type() return Apply(self, [rv, *total_size], [out]) + def infer_shape(self, fgraph, node, shapes): + return [shapes[0]] + def perform(self, node, inputs, output_storage): output_storage[0][0] = inputs[0] diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py index 0903d2775..deedfc8d9 100644 --- a/pymc/variational/opvi.py +++ b/pymc/variational/opvi.py @@ -74,6 +74,7 @@ from pymc.pytensorf import ( SeedSequenceSeed, compile, + constant_fold, find_rng_nodes, reseed_rngs, ) @@ -1105,7 +1106,10 @@ def symbolic_normalizing_constant(self): t = self.to_flat_input( pt.max( [ - get_scaling(v.owner.inputs[1:], v.shape) + get_scaling( + v.owner.inputs[1:], + constant_fold([v.owner.inputs[0].shape], raise_not_constant=False), + ) for v in self.group if isinstance(v.owner.op, MinibatchRandomVariable) ] @@ -1272,7 +1276,10 @@ def symbolic_normalizing_constant(self): t = pt.max( self.collect("symbolic_normalizing_constant") + [ - get_scaling(obs.owner.inputs[1:], obs.shape) + get_scaling( + obs.owner.inputs[1:], + constant_fold([obs.owner.inputs[0].shape], raise_not_constant=False), + ) for obs in self.model.observed_RVs if isinstance(obs.owner.op, MinibatchRandomVariable) ] diff --git a/tests/variational/test_opvi.py b/tests/variational/test_opvi.py index d692b3001..0f40572f7 100644 --- a/tests/variational/test_opvi.py +++ b/tests/variational/test_opvi.py @@ -20,6 +20,7 @@ import pymc as pm +from pymc.testing import assert_no_rvs from pymc.variational import opvi from pymc.variational.approximations import ( Empirical, @@ -278,3 +279,18 @@ def test_logq_globals(three_var_approx): es = symbolic_logq.eval() assert e.shape == () assert es.shape == (2,) + + +def test_symbolic_normalizing_constant_no_rvs(): + # Test that RVs aren't included in the graph of symbolic_normalizing_constant + rng = np.random.default_rng() + + with pm.Model() as m: + obs = pm.Data("obs", rng.normal(size=(1000,))) + obs_batch = pm.Minibatch(obs, batch_size=128) + x = pm.Normal("x") # Need at least one Free_RV in the graph + y_hat = pm.Flat("y_hat", observed=obs_batch, total_size=1000) + + step = pm.ADVI() + + assert_no_rvs(step.approx.symbolic_normalizing_constant)