|
14 | 14 | import pytensor.xtensor as ptx
|
15 | 15 | import pytensor.xtensor.random as ptxr
|
16 | 16 |
|
| 17 | +from pytensor.tensor import as_tensor |
| 18 | +from pytensor.xtensor import as_xtensor |
17 | 19 | from pytensor.xtensor import random as pxr
|
18 | 20 |
|
19 | 21 | from pymc.dims.distributions.core import VectorDimDistribution
|
| 22 | +from pymc.dims.distributions.transforms import ZeroSumTransform |
| 23 | +from pymc.distributions.multivariate import ZeroSumNormalRV |
| 24 | +from pymc.util import UNSET |
20 | 25 |
|
21 | 26 |
|
22 | 27 | class Categorical(VectorDimDistribution):
|
@@ -114,3 +119,80 @@ def dist(cls, mu, cov=None, *, chol=None, lower=True, core_dims=None, **kwargs):
|
114 | 119 | cov = chol.dot(chol.rename({d0: safe_name}), dim=d1).rename({safe_name: d1})
|
115 | 120 |
|
116 | 121 | return super().dist([mu, cov], core_dims=core_dims, **kwargs)
|
| 122 | + |
| 123 | + |
class ZeroSumNormal(VectorDimDistribution):
    """Zero-sum multivariate normal distribution.

    Parameters
    ----------
    sigma : xtensor_like, optional
        The standard deviation of the underlying unconstrained normal distribution.
        Defaults to 1.0. It cannot have core dimensions.
    core_dims : Sequence of str, optional
        The axes along which the zero-sum constraint is applied.
    **kwargs
        Additional keyword arguments used to define the distribution.

    Returns
    -------
    XTensorVariable
        An xtensor variable representing the zero-sum multivariate normal distribution.
    """

    @classmethod
    def __new__(
        cls, *args, core_dims=None, dims=None, default_transform=UNSET, observed=None, **kwargs
    ):
        """Intercept model-variable creation to set up the default transform and dims."""
        if core_dims is not None:
            # Normalize a single dim name to a 1-tuple so downstream code can iterate it.
            if isinstance(core_dims, str):
                core_dims = (core_dims,)

            # Unobserved variables default to the zero-sum transform so sampling happens
            # on the unconstrained space; an explicit user-provided transform wins.
            if observed is None and default_transform is UNSET:
                default_transform = ZeroSumTransform(dims=core_dims)

        # If the user didn't specify dims, take it from core_dims
        # We need them to be forwarded to dist in the `dim_lengths` argument
        if dims is None and core_dims is not None:
            dims = (..., *core_dims)

        return super().__new__(
            *args,
            core_dims=core_dims,
            dims=dims,
            default_transform=default_transform,
            observed=observed,
            **kwargs,
        )

    @classmethod
    def dist(cls, sigma=1.0, *, core_dims=None, dim_lengths, **kwargs):
        """Build the unnamed distribution variable.

        Raises
        ------
        ValueError
            If ``core_dims`` is missing or empty — the zero-sum constraint needs
            at least one axis to act on.
        """
        if isinstance(core_dims, str):
            core_dims = (core_dims,)
        if not core_dims:
            raise ValueError("ZeroSumNormal requires at least 1 core_dims")

        # Pack the lengths of the constrained axes into a single vector carried on a
        # dummy "_" dim, so they travel through the RV machinery as one input.
        support_dims = as_xtensor(
            as_tensor([dim_lengths[core_dim] for core_dim in core_dims]), dims=("_",)
        )
        sigma = cls._as_xtensor(sigma)

        return super().dist(
            [sigma, support_dims], core_dims=core_dims, dim_lengths=dim_lengths, **kwargs
        )

    @classmethod
    def xrv_op(cls, sigma, support_dims, core_dims, extra_dims=None, rng=None):
        """Wrap the core ZeroSumNormalRV op as an xtensor random variable."""
        sigma = as_xtensor(sigma)
        support_dims = as_xtensor(support_dims, dims=("_",))
        support_shape = support_dims.values
        core_rv = ZeroSumNormalRV.rv_op(sigma=sigma.values, support_shape=support_shape).owner.op
        xop = pxr.as_xrv(
            core_rv,
            # sigma is a scalar core input; support_dims contributes its one ("_") dim.
            core_inps_dims_map=[(), (0,)],
            core_out_dims_map=tuple(range(1, len(core_dims) + 1)),
        )
        # Dummy "_" core dim to absorb the support_shape vector
        # If ZeroSumNormal expected a scalar per support dim, this wouldn't be needed
        return xop(sigma, support_dims, core_dims=("_", *core_dims), extra_dims=extra_dims, rng=rng)
0 commit comments