Skip to content

Commit ad35577

Browse files
authored
Add folded distribution (#1050)
1 parent 05d37d6 commit ad35577

File tree

4 files changed

+61
-1
lines changed

4 files changed

+61
-1
lines changed

docs/source/distributions.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,14 @@ ExpandedDistribution
1717
:show-inheritance:
1818
:member-order: bysource
1919

20+
FoldedDistribution
21+
------------------
22+
.. autoclass:: numpyro.distributions.distribution.FoldedDistribution
23+
:members:
24+
:undoc-members:
25+
:show-inheritance:
26+
:member-order: bysource
27+
2028
ImproperUniform
2129
---------------
2230
.. autoclass:: numpyro.distributions.distribution.ImproperUniform

numpyro/distributions/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
Delta,
6565
Distribution,
6666
ExpandedDistribution,
67+
FoldedDistribution,
6768
ImproperUniform,
6869
Independent,
6970
MaskedDistribution,
@@ -109,6 +110,7 @@
109110
"Distribution",
110111
"Exponential",
111112
"ExpandedDistribution",
113+
"FoldedDistribution",
112114
"Gamma",
113115
"GammaPoisson",
114116
"GaussianRandomWalk",

numpyro/distributions/distribution.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,9 @@
3535

3636
from jax import lax, tree_util
3737
import jax.numpy as jnp
38+
from jax.scipy.special import logsumexp
3839

39-
from numpyro.distributions.transforms import ComposeTransform, Transform
40+
from numpyro.distributions.transforms import AbsTransform, ComposeTransform, Transform
4041
from numpyro.distributions.util import (
4142
lazy_property,
4243
promote_shapes,
@@ -1051,6 +1052,38 @@ def tree_flatten(self):
10511052
)
10521053

10531054

1055+
class FoldedDistribution(TransformedDistribution):
1056+
"""
1057+
Equivalent to ``TransformedDistribution(base_dist, AbsTransform())``,
1058+
but additionally supports :meth:`log_prob` .
1059+
1060+
:param Distribution base_dist: A univariate distribution to reflect.
1061+
"""
1062+
1063+
support = constraints.positive
1064+
1065+
def __init__(self, base_dist, validate_args=None):
1066+
if base_dist.event_shape:
1067+
raise ValueError("Only univariate distributions can be folded.")
1068+
super().__init__(base_dist, AbsTransform(), validate_args=validate_args)
1069+
1070+
@validate_sample
1071+
def log_prob(self, value):
1072+
dim = max(len(self.batch_shape), jnp.ndim(value))
1073+
plus_minus = jnp.array([1.0, -1.0]).reshape((2,) + (1,) * dim)
1074+
return logsumexp(self.base_dist.log_prob(plus_minus * value), axis=0)
1075+
1076+
def tree_flatten(self):
1077+
base_flatten, base_aux = self.base_dist.tree_flatten()
1078+
return base_flatten, (type(self.base_dist), base_aux)
1079+
1080+
@classmethod
1081+
def tree_unflatten(cls, aux_data, params):
1082+
base_cls, base_aux = aux_data
1083+
base_dist = base_cls.tree_unflatten(base_aux, params)
1084+
return cls(base_dist)
1085+
1086+
10541087
class Delta(Distribution):
10551088
arg_constraints = {
10561089
"v": constraints.dependent(is_discrete=False),

test/test_distributions.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,19 @@ def __init__(self, rate, *, validate_args=None):
109109
super().__init__(rate, is_sparse=True, validate_args=validate_args)
110110

111111

112+
class FoldedNormal(dist.FoldedDistribution):
113+
arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
114+
115+
def __init__(self, loc, scale, validate_args=None):
116+
self.loc = loc
117+
self.scale = scale
118+
super().__init__(dist.Normal(loc, scale), validate_args=validate_args)
119+
120+
@classmethod
121+
def tree_unflatten(cls, aux_data, params):
122+
return dist.FoldedDistribution.tree_unflatten(aux_data, params)
123+
124+
112125
_DIST_MAP = {
113126
dist.BernoulliProbs: lambda probs: osp.bernoulli(p=probs),
114127
dist.BernoulliLogits: lambda logits: osp.bernoulli(p=_to_probs_bernoulli(logits)),
@@ -189,6 +202,8 @@ def get_sp_dist(jax_dist):
189202
T(dist.Gumbel, 0.0, 1.0),
190203
T(dist.Gumbel, 0.5, 2.0),
191204
T(dist.Gumbel, jnp.array([0.0, 0.5]), jnp.array([1.0, 2.0])),
205+
T(FoldedNormal, 2.0, 4.0),
206+
T(FoldedNormal, jnp.array([2.0, 50.0]), jnp.array([4.0, 100.0])),
192207
T(dist.HalfCauchy, 1.0),
193208
T(dist.HalfCauchy, jnp.array([1.0, 2.0])),
194209
T(dist.HalfNormal, 1.0),
@@ -1102,6 +1117,8 @@ def fn(*args):
11021117
def test_mean_var(jax_dist, sp_dist, params):
11031118
if jax_dist is _ImproperWrapper:
11041119
pytest.skip("Improper distribution does not has mean/var implemented")
1120+
if jax_dist is FoldedNormal:
1121+
pytest.skip("Folded distribution does not has mean/var implemented")
11051122
if jax_dist in (
11061123
_TruncatedNormal,
11071124
dist.LeftTruncatedDistribution,

0 commit comments

Comments
 (0)