Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pymc_extras/model/marginal/graph_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from pymc import SymbolicRandomVariable
from pymc.model.fgraph import ModelVar
from pymc.variational.minibatch_rv import MinibatchRandomVariable
from pytensor.graph import Variable, ancestors
from pytensor.graph.basic import io_toposort
from pytensor.tensor import TensorType, TensorVariable
Expand Down Expand Up @@ -140,6 +141,9 @@ def _subgraph_batch_dim_connection(var_dims: VAR_DIMS, input_vars, output_vars)
elif isinstance(node.op, ModelVar):
var_dims[node.outputs[0]] = inputs_dims[0]

elif isinstance(node.op, MinibatchRandomVariable):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit I kind of tried to order the if elif branches by likelihood, in which case I would put this last

var_dims[node.outputs[0]] = inputs_dims[0]

elif isinstance(node.op, DimShuffle):
[input_dims] = inputs_dims
output_dims = tuple(None if i == "x" else input_dims[i] for i in node.op.new_order)
Expand Down
8 changes: 8 additions & 0 deletions tests/model/marginal/test_graph_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pytest

from pymc.distributions import CustomDist
from pymc.variational.minibatch_rv import create_minibatch_rv
from pytensor.tensor.type_other import NoneTypeT

from pymc_extras.model.marginal.graph_analysis import (
Expand Down Expand Up @@ -160,6 +161,13 @@ def test_random_variable(self):
with pytest.raises(ValueError, match="Use of known dimensions"):
subgraph_batch_dim_connection(inp, [invalid_out])

def test_minibatched_random_variable(self):
inp = pt.tensor(shape=(4, 3, 2))
out1 = pt.random.normal(loc=inp)
out2 = create_minibatch_rv(out1, total_size=(10, 10, 10))
[dims1] = subgraph_batch_dim_connection(inp, [out2])
assert dims1 == (0, 1, 2)

def test_symbolic_random_variable(self):
inp = pt.tensor(shape=(4, 3, 2))

Expand Down