Skip to content

Commit 3a8d898

Browse files
committed
Handle single output and fix transform
1 parent 126e76b commit 3a8d898

File tree

3 files changed

+9
-3
lines changed

3 files changed

+9
-3
lines changed

pymc/distributions/distribution.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -415,6 +415,14 @@ def dist(
415415
@_get_measurable_outputs.register(SymbolicRandomVariable)
416416
def _get_measurable_outputs_symbolic_random_variable(op, node):
417417
# This tells Aeppl that any non RandomType outputs are measurable
418+
419+
# Assume that if there is one default_output, that's the only one that is measurable
420+
# In the rare case this is not what one wants, a specialized _get_measuarable_outputs
421+
# can dispatch for a subclassed Op
422+
if op.default_output is not None:
423+
return [node.default_output()]
424+
425+
# Otherwise assume that any outputs that are not of RandomType are measurable
418426
return [out for out in node.outputs if not isinstance(out.type, RandomType)]
419427

420428

pymc/distributions/multivariate.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2482,8 +2482,6 @@ def dist(cls, sigma=1, zerosum_axes=None, support_shape=None, **kwargs):
24822482
ndim_supp=zerosum_axes,
24832483
)
24842484

2485-
# print(f"{support_shape.eval() = }")
2486-
24872485
if support_shape is None:
24882486
if zerosum_axes > 0:
24892487
raise ValueError("You must specify shape or support_shape parameter")

pymc/distributions/transforms.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,7 @@ def __init__(self, zerosum_axes):
291291
By default (``zerosum_axes=[-1]``), the sum-to-zero constraint is imposed
292292
on the last axis.
293293
"""
294-
self.zerosum_axes = zerosum_axes
294+
self.zerosum_axes = tuple(int(axis) for axis in zerosum_axes)
295295

296296
def forward(self, value, *rv_inputs):
297297
for axis in self.zerosum_axes:

0 commit comments

Comments
 (0)