Skip to content

Commit c1be6bc

Browse files
committed
Allow Dim version of simple SymbolicRandomVariables
1 parent 9a04236 commit c1be6bc

File tree

3 files changed

+56
-1
lines changed

3 files changed

+56
-1
lines changed

pymc/dims/distributions/scalar.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,16 @@
1414
import pytensor.xtensor as ptx
1515
import pytensor.xtensor.random as pxr
1616

17+
from pytensor.xtensor import as_xtensor
18+
1719
from pymc.dims.distributions.core import (
1820
DimDistribution,
1921
PositiveDimDistribution,
2022
UnitDimDistribution,
2123
)
2224
from pymc.distributions.continuous import Beta as RegularBeta
2325
from pymc.distributions.continuous import Gamma as RegularGamma
24-
from pymc.distributions.continuous import flat, halfflat
26+
from pymc.distributions.continuous import HalfStudentTRV, flat, halfflat
2527

2628

2729
def _get_sigma_from_either_sigma_or_tau(*, sigma, tau):
@@ -89,6 +91,21 @@ def dist(cls, nu, mu=0, sigma=None, *, lam=None, **kwargs):
8991
return super().dist([nu, mu, sigma], **kwargs)
9092

9193

94+
class HalfStudentT(PositiveDimDistribution):
95+
@classmethod
96+
def dist(cls, nu, sigma=None, *, lam=None, **kwargs):
97+
sigma = _get_sigma_from_either_sigma_or_tau(sigma=sigma, tau=lam)
98+
return super().dist([nu, sigma], **kwargs)
99+
100+
@classmethod
101+
def xrv_op(self, nu, sigma, core_dims=None, extra_dims=None, rng=None):
102+
nu = as_xtensor(nu)
103+
sigma = as_xtensor(sigma)
104+
core_rv = HalfStudentTRV.rv_op(nu=nu.values, sigma=sigma.values).owner.op
105+
xop = pxr.as_xrv(core_rv)
106+
return xop(nu, sigma, core_dims=core_dims, extra_dims=extra_dims, rng=rng)
107+
108+
92109
class Cauchy(DimDistribution):
93110
xrv_op = pxr.cauchy
94111

pymc/distributions/distribution.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,8 +370,29 @@ def __init__(
370370

371371
kwargs.setdefault("inline", True)
372372
kwargs.setdefault("strict", True)
373+
# Many RVS have a size argument, even when this is `None` and is therefore unused
374+
kwargs.setdefault("on_unused_input", "ignore")
375+
if hasattr(self, "name"):
376+
kwargs.setdefault("name", self.name)
373377
super().__init__(*args, **kwargs)
374378

379+
def make_node(self, *inputs):
380+
# If we try to build the RV with a different size type (vector -> None or None -> vector)
381+
# We need to rebuild the Op with new size type in the inner graph
382+
if self.extended_signature is not None:
383+
(rng_arg_idxs, size_arg_idx, param_idxs), _ = self.get_input_output_type_idxs(
384+
self.extended_signature
385+
)
386+
if size_arg_idx is not None and len(rng_arg_idxs) == 1:
387+
new_size_type = normalize_size_param(inputs[size_arg_idx]).type
388+
if not self.input_types[size_arg_idx].in_same_class(new_size_type):
389+
params = [inputs[idx] for idx in param_idxs]
390+
size = inputs[size_arg_idx]
391+
rng = inputs[rng_arg_idxs[0]]
392+
return self.rebuild_rv(*params, size=size, rng=rng).owner
393+
394+
return super().make_node(*inputs)
395+
375396
def update(self, node: Apply) -> dict[Variable, Variable]:
376397
"""Symbolic update expression for input random state variables.
377398

tests/dims/distributions/test_scalar.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
HalfCauchy,
2323
HalfFlat,
2424
HalfNormal,
25+
HalfStudentT,
2526
InverseGamma,
2627
Laplace,
2728
LogNormal,
@@ -119,6 +120,22 @@ def test_studentt():
119120
assert_equivalent_logp_graph(model, reference_model)
120121

121122

123+
def test_halfstudentt():
124+
coords = {"a": range(3)}
125+
with Model(coords=coords) as model:
126+
HalfStudentT("x", nu=1, dims="a")
127+
HalfStudentT("y", nu=1, sigma=3, dims="a")
128+
HalfStudentT("z", nu=1, lam=3, dims="a")
129+
130+
with Model(coords=coords) as reference_model:
131+
regular_distributions.HalfStudentT("x", nu=1, dims="a")
132+
regular_distributions.HalfStudentT("y", nu=1, sigma=3, dims="a")
133+
regular_distributions.HalfStudentT("z", nu=1, lam=3, dims="a")
134+
135+
assert_equivalent_random_graph(model, reference_model)
136+
assert_equivalent_logp_graph(model, reference_model)
137+
138+
122139
def test_cauchy():
123140
coords = {"a": range(3)}
124141
with Model(coords=coords) as model:

0 commit comments

Comments
 (0)