Skip to content

Commit 61eb3f8

Browse files
committed
ZeroSumNormal: initial commit
1 parent 5d6af4d commit 61eb3f8

File tree

2 files changed

+1622
-1969
lines changed

2 files changed

+1622
-1969
lines changed

examples/generalized_linear_models/GLM-ZeroSumNormal.ipynb

Lines changed: 1545 additions & 1969 deletions
Large diffs are not rendered by default.

examples/generalized_linear_models/ZeroSumNormal.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
<<<<<<< HEAD
12
from typing import List
23

34
try:
@@ -138,3 +139,79 @@ def _distr_parameters_for_repr(self):
138139

139140
def logcdf(self, value):
140141
raise NotImplementedError()
142+
=======
143+
import pymc3 as pm
144+
import numpy as np
145+
import pandas as pd
146+
from typing import *
147+
import aesara
148+
import aesara.tensor as aet
149+
150+
151+
def ZeroSumNormal(
152+
name: str,
153+
sigma: Optional[float] = None,
154+
*,
155+
dims: Union[str, Tuple[str]],
156+
model: Optional[pm.Model] = None,
157+
):
158+
"""
159+
Multivariate normal, such that sum(x, axis=-1) = 0.
160+
161+
Parameters
162+
----------
163+
name: str
164+
String name representation of the PyMC variable.
165+
sigma: Optional[float], defaults to None
166+
Scale for the Normal distribution. If ``None``, a standard Normal is used.
167+
dims: Union[str, Tuple[str]]
168+
Dimension names for the shape of the distribution.
169+
See https://docs.pymc.io/pymc-examples/examples/pymc3_howto/data_container.html for an example.
170+
model: Optional[pm.Model], defaults to None
171+
PyMC model instance. If ``None``, a model instance is created.
172+
"""
173+
if isinstance(dims, str):
174+
dims = (dims,)
175+
176+
model = pm.modelcontext(model)
177+
*dims_pre, dim = dims
178+
dim_trunc = f"{dim}_truncated_"
179+
(shape,) = model.shape_from_dims((dim,))
180+
assert shape >= 1
181+
182+
model.add_coords({f"{dim}_truncated_": pd.RangeIndex(shape - 1)})
183+
raw = pm.Normal(f"{name}_truncated_", dims=tuple(dims_pre) + (dim_trunc,), sigma=sigma)
184+
Q = make_sum_zero_hh(shape)
185+
draws = aet.dot(raw, Q[:, 1:].T)
186+
187+
#if sigma is not None:
188+
# draws = sigma * draws
189+
190+
return pm.Deterministic(name, draws, dims=dims)
191+
192+
193+
194+
def make_sum_zero_hh(N: int) -> np.ndarray:
195+
"""
196+
Build a householder transformation matrix that maps e_1 to a vector of all 1s.
197+
"""
198+
e_1 = np.zeros(N)
199+
e_1[0] = 1
200+
a = np.ones(N)
201+
a /= np.sqrt(a @ a)
202+
v = a + e_1
203+
v /= np.sqrt(v @ v)
204+
return np.eye(N) - 2 * np.outer(v, v)
205+
206+
def make_sum_zero_hh(N: int) -> np.ndarray:
207+
"""
208+
Build a householder transformation matrix that maps e_1 to a vector of all 1s.
209+
"""
210+
e_1 = np.zeros(N)
211+
e_1[0] = 1
212+
a = np.ones(N)
213+
a /= np.sqrt(a @ a)
214+
v = a + e_1
215+
v /= np.sqrt(v @ v)
216+
return np.eye(N) - 2 * np.outer(v, v)
217+
>>>>>>> 2da3052 (ZeroSumNormal: initial commit)

0 commit comments

Comments
 (0)