Added ZeroSumNormal Distribution #4776
Changes from 1 commit
```diff
@@ -65,6 +65,7 @@
     "Lognormal",
     "ChiSquared",
     "HalfNormal",
+    "ZeroSumNormal",
     "Wald",
     "Pareto",
     "InverseGamma",
```
@@ -924,6 +925,67 @@ def logcdf(self, value):

```python
        )


class ZeroSumNormal(Continuous):
    def __init__(self, sigma=1, zerosum_dims=None, zerosum_axes=None, **kwargs):
        shape = kwargs.get("shape", ())
        dims = kwargs.get("dims", None)
        if isinstance(shape, int):
            shape = (shape,)

        if isinstance(dims, str):
            dims = (dims,)

        self.mu = self.median = self.mode = tt.zeros(shape)
        self.sigma = tt.as_tensor_variable(sigma)

        if zerosum_dims is None and zerosum_axes is None:
            if shape:
                zerosum_axes = (-1,)
            else:
                zerosum_axes = ()

        if isinstance(zerosum_axes, int):
            zerosum_axes = (zerosum_axes,)

        if isinstance(zerosum_dims, str):
            zerosum_dims = (zerosum_dims,)

        if zerosum_axes is not None and zerosum_dims is not None:
            raise ValueError("Only one of zerosum_axes and zerosum_dims can be specified.")

        if zerosum_dims is not None:
            if dims is None:
                raise ValueError("zerosum_dims can only be used with the dims kwargs.")
            zerosum_axes = []
            for dim in zerosum_dims:
                zerosum_axes.append(dims.index(dim))
        self.zerosum_axes = [a if a >= 0 else len(shape) + a for a in zerosum_axes]
```
Review comment: Enforcing positive axes here leads to problems when you draw samples from the prior predictive. It's better to replace this line with this suggested change: …
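A minimal standalone sketch of the issue described above (illustrative names only, not the reviewer's suggested change): once prior predictive sampling prepends a sample dimension, an axis that was normalized to a positive index no longer points at the intended dimension, while a negative index still does.

```python
import numpy as np

# A draw of shape (4, 3) where the last axis should sum to zero.
draw = np.random.default_rng(0).normal(size=(4, 3))

# Prior predictive sampling typically prepends a sample dimension.
prior_pred = np.random.default_rng(1).normal(size=(500, 4, 3))

pos_axis, neg_axis = 1, -1   # both mean "last axis" for the (4, 3) draw

# With the extra sample dimension, the positive index now hits the wrong axis...
wrong = prior_pred - prior_pred.mean(axis=pos_axis, keepdims=True)
# ...while the negative index still centers the intended (last) axis.
right = prior_pred - prior_pred.mean(axis=neg_axis, keepdims=True)

print(np.allclose(wrong.sum(axis=-1), 0))   # False
print(np.allclose(right.sum(axis=-1), 0))   # True
```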
```python
        if "transform" not in kwargs or kwargs["transform"] == None:
            kwargs["transform"] = transforms.ZeroSumTransform(zerosum_axes)

        super().__init__(**kwargs)

    def logp(self, value):
        return Normal.dist(sigma=self.sigma).logp(value)
```
Review comment: If we don’t add the scaling of sigma here, our …

Reply: How so? Isn’t …

Review comment: This logp is somewhat strange still. From the math side this should be `pm.MvNormal` with `cov = I - J / n`, where `J` is a matrix of all 1s. We don't want to write it like this though, because we don't want to do matrix factorization, and `pm.MvNormal` doesn't work if an eigenvalue is 0.

Review comment: We can avoid the … Suggested change: …

Review comment: I came across this wiki section that talks about the degenerate MvNormal case (which is what we have with the ZeroSumNormal). We could use that formula as the expected logp value and test if the …

```python
import numpy as np

def pseudo_log_det(A, tol=1e-13):
    v, w = np.linalg.eigh(A)
    return np.sum(np.log(np.where(np.abs(v) >= tol, v, 1)), axis=-1)

def logp(value, sigma):
    n = value.shape[-1]
    cov = np.asarray(sigma)[..., None, None] ** 2 * (np.eye(n) - np.ones((n, n)) / n)
    psdet = 0.5 * pseudo_log_det(2 * np.pi * cov)
    exp = 0.5 * (value[..., None, :] @ np.linalg.pinv(cov) @ value[..., None])[..., 0, 0]
    return np.where(np.abs(np.sum(value, axis=-1)) < 1e-9, -psdet - exp, -np.inf)
```

Review comment: I ran a few tests with the logp and it looks like the …

```python
import numpy as np

def logp(value, sigma):
    n = value.shape[-1]
    cov = np.asarray(sigma)[..., None, None] ** 2 * (np.eye(n) - np.ones((n, n)) / n)
    v, w = np.linalg.eigh(cov)
    psdet = 0.5 * (np.sum(np.log(v[..., 1:])) + (n - 1) * np.log(2 * np.pi))
    cov_pinv = w[:, 1:] @ np.diag(1 / v[1:]) @ w[:, 1:].T
    exp = 0.5 * (value[..., None, :] @ cov_pinv @ value[..., None])[..., 0, 0]
    return np.where(np.abs(np.sum(value, axis=-1)) < 1e-9, -psdet - exp, -np.inf)
```

This is different from the `psdet = n * (0.5 * np.log(2 * np.pi) + np.log(sigma))` implied by the current logp. This means that we have to multiply the …
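To make the comparison above concrete, here is a minimal standalone check (the helper `zerosum_logp` is an assumed name, adapted from the reviewer's snippet, not part of the PR): the naive sum of independent Normal logps and the degenerate-MvNormal logp differ by an offset that depends on sigma, which matters once sigma is itself a free parameter.

```python
import numpy as np
from scipy import stats

def zerosum_logp(value, sigma):
    # Degenerate-MvNormal logp with cov = sigma**2 * (I - J/n), via the pseudo-determinant.
    n = value.shape[-1]
    cov = sigma ** 2 * (np.eye(n) - np.ones((n, n)) / n)
    v, w = np.linalg.eigh(cov)
    psdet = 0.5 * (np.sum(np.log(v[1:])) + (n - 1) * np.log(2 * np.pi))
    cov_pinv = w[:, 1:] @ np.diag(1 / v[1:]) @ w[:, 1:].T
    quad = 0.5 * value @ cov_pinv @ value
    return -psdet - quad

sigma, n = 1.5, 4
x = np.random.default_rng(0).normal(size=n)
x = x - x.mean()                                   # project onto the zero-sum subspace

naive = stats.norm(0, sigma).logpdf(x).sum()       # what Normal.dist(sigma).logp(value) sums to
degenerate = zerosum_logp(x, sigma)

# Constant offset in x, but one that depends on sigma: -0.5*log(2*pi) - log(sigma).
print(naive - degenerate)
print(-0.5 * np.log(2 * np.pi) - np.log(sigma))
```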
```python
    def _random(self, scale, size):
        samples = stats.norm.rvs(loc=0, scale=scale, size=size)
        for axis in self.zerosum_axes:
            samples -= np.mean(samples, axis=axis, keepdims=True)
        return samples

    def random(self, point=None, size=None):
        (sigma,) = draw_values([self.sigma], point=point, size=size)
        return generate_samples(self._random, scale=sigma, dist_shape=self.shape, size=size)

    def _distr_parameters_for_repr(self):
        return ["sigma"]

    def logcdf(self, value):
        raise NotImplementedError()


class Wald(PositiveContinuous):
    r"""
    Wald log-likelihood.
```
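As a standalone sketch of the sampling idea in `_random` above (plain NumPy/SciPy, outside PyMC, with an assumed helper name): subtracting the mean along each zero-sum axis yields draws that sum to zero along those axes.

```python
import numpy as np
from scipy import stats

def zerosum_draws(sigma, size, zerosum_axes=(-1,)):
    # Draw unconstrained normals, then demean along every zero-sum axis.
    samples = stats.norm.rvs(loc=0, scale=sigma, size=size, random_state=123)
    for axis in zerosum_axes:
        samples = samples - samples.mean(axis=axis, keepdims=True)
    return samples

draws = zerosum_draws(sigma=2.0, size=(1000, 5, 3), zerosum_axes=(-1, -2))
print(np.allclose(draws.sum(axis=-1), 0))   # True: sums over the last axis are zero
print(np.allclose(draws.sum(axis=-2), 0))   # True: sums over the second-to-last axis are zero
```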
```diff
@@ -98,7 +98,7 @@ def __new__(cls, name, *args, **kwargs):
             raise TypeError("observed needs to be data but got: {}".format(type(data)))
         total_size = kwargs.pop("total_size", None)

-        dims = kwargs.pop("dims", None)
+        dims = kwargs["dims"] if "dims" in kwargs else None
```
Review comment: This is a breaking change to the …

Reply: I reverted that change, as we're now using …
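A minimal sketch (plain Python, hypothetical function names) of the behavioral difference the comment points at: `kwargs.pop("dims", None)` removes the key before the kwargs are forwarded, while `kwargs["dims"] if "dims" in kwargs else None` leaves it in place for downstream consumers.

```python
# Hypothetical helpers, only to illustrate the two kwarg-handling styles above.
def handle_with_pop(**kwargs):
    dims = kwargs.pop("dims", None)      # removes "dims" from kwargs
    return dims, kwargs

def handle_without_pop(**kwargs):
    dims = kwargs["dims"] if "dims" in kwargs else None   # leaves "dims" in kwargs
    return dims, kwargs

print(handle_with_pop(dims=("city",), shape=(3,)))
# (('city',), {'shape': (3,)})
print(handle_without_pop(dims=("city",), shape=(3,)))
# (('city',), {'dims': ('city',), 'shape': (3,)})
```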
```python
        has_shape = "shape" in kwargs
        shape = kwargs.pop("shape", None)
        if dims is not None:
```
Review comment: I think it makes no sense to have a `ZeroSumNormal` when `shape=()` or `None`. In that case, the RV should also be exactly equal to zero. I think that we should test if `shape is None or len(shape) == 0` and raise a `ValueError` in that case, with a message saying that `ZeroSumNormal` is only defined for RVs that are not scalar.
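A hedged sketch of the guard this comment proposes (not code from the PR; the helper name is hypothetical), assumed to run after `shape` has been normalized to a tuple:

```python
def _check_zerosum_shape(shape):
    # Reject scalar RVs, as suggested in the comment above.
    if shape is None or len(shape) == 0:
        raise ValueError("ZeroSumNormal is only defined for non-scalar RVs; pass a non-empty shape or dims.")

_check_zerosum_shape((3,))   # fine
# _check_zerosum_shape(())   # would raise ValueError
```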