Skip to content

Commit c212b85

Browse files
committed
Allow Dim version of simple SymbolicRandomVariables
1 parent fe0a537 commit c212b85

File tree

3 files changed

+60
-4
lines changed

3 files changed

+60
-4
lines changed

pymc/dims/distributions/scalar.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,17 @@
1414
import pytensor.xtensor as ptx
1515
import pytensor.xtensor.random as pxr
1616

17+
from pytensor.xtensor import as_xtensor
18+
1719
from pymc.dims.distributions.core import (
1820
DimDistribution,
1921
PositiveDimDistribution,
2022
UnitDimDistribution,
2123
)
2224
from pymc.distributions.continuous import Beta as RegularBeta
2325
from pymc.distributions.continuous import Gamma as RegularGamma
26+
from pymc.distributions.continuous import HalfStudentTRV, flat, halfflat
2427
from pymc.distributions.continuous import InverseGamma as RegularInverseGamma
25-
from pymc.distributions.continuous import flat, halfflat
2628

2729

2830
def _get_sigma_from_either_sigma_or_tau(*, sigma, tau):
@@ -90,6 +92,21 @@ def dist(cls, nu, mu=0, sigma=None, *, lam=None, **kwargs):
9092
return super().dist([nu, mu, sigma], **kwargs)
9193

9294

95+
class HalfStudentT(PositiveDimDistribution):
96+
@classmethod
97+
def dist(cls, nu, sigma=None, *, lam=None, **kwargs):
98+
sigma = _get_sigma_from_either_sigma_or_tau(sigma=sigma, tau=lam)
99+
return super().dist([nu, sigma], **kwargs)
100+
101+
@classmethod
102+
def xrv_op(self, nu, sigma, core_dims=None, extra_dims=None, rng=None):
103+
nu = as_xtensor(nu)
104+
sigma = as_xtensor(sigma)
105+
core_rv = HalfStudentTRV.rv_op(nu=nu.values, sigma=sigma.values).owner.op
106+
xop = pxr._as_xrv(core_rv)
107+
return xop(nu, sigma, core_dims=core_dims, extra_dims=extra_dims, rng=rng)
108+
109+
93110
class Cauchy(DimDistribution):
94111
xrv_op = pxr.cauchy
95112

pymc/distributions/distribution.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
from pytensor.graph.rewriting.basic import in2out
3333
from pytensor.graph.utils import MetaType
3434
from pytensor.tensor.basic import as_tensor_variable
35-
from pytensor.tensor.random.op import RandomVariable
35+
from pytensor.tensor.random.op import RandomVariable, RNGConsumerOp
3636
from pytensor.tensor.random.rewriting import local_subtensor_rv_lift
3737
from pytensor.tensor.random.utils import normalize_size_param
3838
from pytensor.tensor.rewriting.shape import ShapeFeature
@@ -207,7 +207,7 @@ def __get__(self, owner_self, owner_cls):
207207
return self.fget(owner_self if owner_self is not None else owner_cls)
208208

209209

210-
class SymbolicRandomVariable(MeasurableOp, OpFromGraph):
210+
class SymbolicRandomVariable(MeasurableOp, RNGConsumerOp, OpFromGraph):
211211
"""Symbolic Random Variable.
212212
213213
This is a subclass of `OpFromGraph` which is used to encapsulate the symbolic
@@ -294,7 +294,10 @@ def default_output(cls_or_self) -> int | None:
294294
@staticmethod
295295
def get_input_output_type_idxs(
296296
extended_signature: str | None,
297-
) -> tuple[tuple[tuple[int], int | None, tuple[int]], tuple[tuple[int], tuple[int]]]:
297+
) -> tuple[
298+
tuple[tuple[int, ...], int | None, tuple[int, ...]],
299+
tuple[tuple[int, ...], tuple[int, ...]],
300+
]:
298301
"""Parse extended_signature and return indexes for *[rng], [size] and parameters as well as outputs."""
299302
if extended_signature is None:
300303
raise ValueError("extended_signature must be provided")
@@ -367,8 +370,27 @@ def __init__(
367370

368371
kwargs.setdefault("inline", True)
369372
kwargs.setdefault("strict", True)
373+
# Many RVs have a size argument, even when this is `None` and is therefore unused
374+
kwargs.setdefault("on_unused_input", "ignore")
370375
super().__init__(*args, **kwargs)
371376

377+
def make_node(self, *inputs):
378+
# If we try to build the RV with a different size type (vector -> None or None -> vector)
379+
# We need to rebuild the Op with new size type in the inner graph
380+
if self.extended_signature is not None:
381+
(rng_arg_idxs, size_arg_idx, param_idxs), _ = self.get_input_output_type_idxs(
382+
self.extended_signature
383+
)
384+
if size_arg_idx is not None and len(rng_arg_idxs) == 1:
385+
new_size_type = normalize_size_param(inputs[size_arg_idx]).type
386+
if not self.input_types[size_arg_idx].in_same_class(new_size_type):
387+
params = [inputs[idx] for idx in param_idxs]
388+
size = inputs[size_arg_idx]
389+
rng = inputs[rng_arg_idxs[0]]
390+
return self.rebuild_rv(*params, size=size, rng=rng).owner
391+
392+
return super().make_node(*inputs)
393+
372394
def update(self, node: Apply) -> dict[Variable, Variable]:
373395
"""Symbolic update expression for input random state variables.
374396

tests/dims/distributions/test_scalar.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
HalfCauchy,
2323
HalfFlat,
2424
HalfNormal,
25+
HalfStudentT,
2526
InverseGamma,
2627
Laplace,
2728
LogNormal,
@@ -119,6 +120,22 @@ def test_studentt():
119120
assert_equivalent_logp_graph(model, reference_model)
120121

121122

123+
def test_halfstudentt():
124+
coords = {"a": range(3)}
125+
with Model(coords=coords) as model:
126+
HalfStudentT("x", nu=1, dims="a")
127+
HalfStudentT("y", nu=1, sigma=3, dims="a")
128+
HalfStudentT("z", nu=1, lam=3, dims="a")
129+
130+
with Model(coords=coords) as reference_model:
131+
regular_distributions.HalfStudentT("x", nu=1, dims="a")
132+
regular_distributions.HalfStudentT("y", nu=1, sigma=3, dims="a")
133+
regular_distributions.HalfStudentT("z", nu=1, lam=3, dims="a")
134+
135+
assert_equivalent_random_graph(model, reference_model)
136+
assert_equivalent_logp_graph(model, reference_model)
137+
138+
122139
def test_cauchy():
123140
coords = {"a": range(3)}
124141
with Model(coords=coords) as model:

0 commit comments

Comments
 (0)