|
| 1 | +<<<<<<< HEAD |
1 | 2 | from typing import List
|
2 | 3 |
|
3 | 4 | try:
|
@@ -138,3 +139,79 @@ def _distr_parameters_for_repr(self):
|
138 | 139 |
|
139 | 140 | def logcdf(self, value):
|
140 | 141 | raise NotImplementedError()
|
| 142 | +======= |
| 143 | +import pymc3 as pm |
| 144 | +import numpy as np |
| 145 | +import pandas as pd |
| 146 | +from typing import * |
| 147 | +import aesara |
| 148 | +import aesara.tensor as aet |
| 149 | + |
| 150 | + |
| 151 | +def ZeroSumNormal( |
| 152 | + name: str, |
| 153 | + sigma: Optional[float] = None, |
| 154 | + *, |
| 155 | + dims: Union[str, Tuple[str]], |
| 156 | + model: Optional[pm.Model] = None, |
| 157 | +): |
| 158 | + """ |
| 159 | + Multivariate normal, such that sum(x, axis=-1) = 0. |
| 160 | +
|
| 161 | + Parameters |
| 162 | + ---------- |
| 163 | + name: str |
| 164 | + String name representation of the PyMC variable. |
| 165 | + sigma: Optional[float], defaults to None |
| 166 | + Scale for the Normal distribution. If ``None``, a standard Normal is used. |
| 167 | + dims: Union[str, Tuple[str]] |
| 168 | + Dimension names for the shape of the distribution. |
| 169 | + See https://docs.pymc.io/pymc-examples/examples/pymc3_howto/data_container.html for an example. |
| 170 | + model: Optional[pm.Model], defaults to None |
| 171 | + PyMC model instance. If ``None``, a model instance is created. |
| 172 | + """ |
| 173 | + if isinstance(dims, str): |
| 174 | + dims = (dims,) |
| 175 | + |
| 176 | + model = pm.modelcontext(model) |
| 177 | + *dims_pre, dim = dims |
| 178 | + dim_trunc = f"{dim}_truncated_" |
| 179 | + (shape,) = model.shape_from_dims((dim,)) |
| 180 | + assert shape >= 1 |
| 181 | + |
| 182 | + model.add_coords({f"{dim}_truncated_": pd.RangeIndex(shape - 1)}) |
| 183 | + raw = pm.Normal(f"{name}_truncated_", dims=tuple(dims_pre) + (dim_trunc,), sigma=sigma) |
| 184 | + Q = make_sum_zero_hh(shape) |
| 185 | + draws = aet.dot(raw, Q[:, 1:].T) |
| 186 | + |
| 187 | + #if sigma is not None: |
| 188 | + # draws = sigma * draws |
| 189 | + |
| 190 | + return pm.Deterministic(name, draws, dims=dims) |
| 191 | + |
| 192 | + |
| 193 | + |
| 194 | +def make_sum_zero_hh(N: int) -> np.ndarray: |
| 195 | + """ |
| 196 | + Build a householder transformation matrix that maps e_1 to a vector of all 1s. |
| 197 | + """ |
| 198 | + e_1 = np.zeros(N) |
| 199 | + e_1[0] = 1 |
| 200 | + a = np.ones(N) |
| 201 | + a /= np.sqrt(a @ a) |
| 202 | + v = a + e_1 |
| 203 | + v /= np.sqrt(v @ v) |
| 204 | + return np.eye(N) - 2 * np.outer(v, v) |
| 205 | + |
| 206 | +def make_sum_zero_hh(N: int) -> np.ndarray: |
| 207 | + """ |
| 208 | + Build a householder transformation matrix that maps e_1 to a vector of all 1s. |
| 209 | + """ |
| 210 | + e_1 = np.zeros(N) |
| 211 | + e_1[0] = 1 |
| 212 | + a = np.ones(N) |
| 213 | + a /= np.sqrt(a @ a) |
| 214 | + v = a + e_1 |
| 215 | + v /= np.sqrt(v @ v) |
| 216 | + return np.eye(N) - 2 * np.outer(v, v) |
| 217 | +>>>>>>> 2da3052 (ZeroSumNormal: initial commit) |
0 commit comments