Skip to content

Commit 68ad9c0

Browse files
committed
Allow Dim version of simple SymbolicRandomVariables
1 parent 27555e3 commit 68ad9c0

File tree

3 files changed

+54
-1
lines changed

3 files changed

+54
-1
lines changed

pymc/dims/distributions/scalar.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,16 @@
1414
import pytensor.xtensor as ptx
1515
import pytensor.xtensor.random as pxr
1616

17+
from pytensor.xtensor import as_xtensor
18+
1719
from pymc.dims.distributions.core import (
1820
DimDistribution,
1921
PositiveDimDistribution,
2022
UnitDimDistribution,
2123
)
2224
from pymc.distributions.continuous import Beta as RegularBeta
2325
from pymc.distributions.continuous import Gamma as RegularGamma
24-
from pymc.distributions.continuous import flat, halfflat
26+
from pymc.distributions.continuous import HalfStudentTRV, flat, halfflat
2527

2628

2729
def _get_sigma_from_either_sigma_or_tau(*, sigma, tau):
@@ -89,6 +91,21 @@ def dist(cls, nu, mu=0, sigma=None, *, lam=None, **kwargs):
8991
return super().dist([nu, mu, sigma], **kwargs)
9092

9193

94+
class HalfStudentT(PositiveDimDistribution):
95+
@classmethod
96+
def dist(cls, nu, sigma=None, *, lam=None, **kwargs):
97+
sigma = _get_sigma_from_either_sigma_or_tau(sigma=sigma, tau=lam)
98+
return super().dist([nu, sigma], **kwargs)
99+
100+
@classmethod
101+
def xrv_op(self, nu, sigma, core_dims=None, extra_dims=None, rng=None):
102+
nu = as_xtensor(nu)
103+
sigma = as_xtensor(sigma)
104+
core_rv = HalfStudentTRV.rv_op(nu=nu.values, sigma=sigma.values).owner.op
105+
xop = pxr._as_xrv(core_rv)
106+
return xop(nu, sigma, core_dims=core_dims, extra_dims=extra_dims, rng=rng)
107+
108+
92109
class Cauchy(DimDistribution):
93110
xrv_op = pxr.cauchy
94111

pymc/distributions/distribution.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,8 +370,27 @@ def __init__(
370370

371371
kwargs.setdefault("inline", True)
372372
kwargs.setdefault("strict", True)
373+
# Many RVS have a size argument, even when this is `None` and is therefore unused
374+
kwargs.setdefault("on_unused_input", "ignore")
373375
super().__init__(*args, **kwargs)
374376

377+
def make_node(self, *inputs):
378+
# If we try to build the RV with a different size type (vector -> None or None -> vector)
379+
# We need to rebuild the Op with new size type in the inner graph
380+
if self.extended_signature is not None:
381+
(rng_arg_idxs, size_arg_idx, param_idxs), _ = self.get_input_output_type_idxs(
382+
self.extended_signature
383+
)
384+
if size_arg_idx is not None and len(rng_arg_idxs) == 1:
385+
new_size_type = normalize_size_param(inputs[size_arg_idx]).type
386+
if not self.input_types[size_arg_idx].in_same_class(new_size_type):
387+
params = [inputs[idx] for idx in param_idxs]
388+
size = inputs[size_arg_idx]
389+
rng = inputs[rng_arg_idxs[0]]
390+
return self.rebuild_rv(*params, size=size, rng=rng).owner
391+
392+
return super().make_node(*inputs)
393+
375394
def update(self, node: Apply) -> dict[Variable, Variable]:
376395
"""Symbolic update expression for input random state variables.
377396

tests/dims/distributions/test_scalar.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
HalfCauchy,
2323
HalfFlat,
2424
HalfNormal,
25+
HalfStudentT,
2526
InverseGamma,
2627
Laplace,
2728
LogNormal,
@@ -119,6 +120,22 @@ def test_studentt():
119120
assert_equivalent_logp_graph(model, reference_model)
120121

121122

123+
def test_halfstudentt():
124+
coords = {"a": range(3)}
125+
with Model(coords=coords) as model:
126+
HalfStudentT("x", nu=1, dims="a")
127+
HalfStudentT("y", nu=1, sigma=3, dims="a")
128+
HalfStudentT("z", nu=1, lam=3, dims="a")
129+
130+
with Model(coords=coords) as reference_model:
131+
regular_distributions.HalfStudentT("x", nu=1, dims="a")
132+
regular_distributions.HalfStudentT("y", nu=1, sigma=3, dims="a")
133+
regular_distributions.HalfStudentT("z", nu=1, lam=3, dims="a")
134+
135+
assert_equivalent_random_graph(model, reference_model)
136+
assert_equivalent_logp_graph(model, reference_model)
137+
138+
122139
def test_cauchy():
123140
coords = {"a": range(3)}
124141
with Model(coords=coords) as model:

0 commit comments

Comments
 (0)