Skip to content

Commit d52209c

Browse files
Add sign for bijective scalar transforms and generic cdf/icdf implementation for TransformedDistributions. (#1853)
1 parent e0d450b commit d52209c

File tree

5 files changed

+60
-18
lines changed

5 files changed

+60
-18
lines changed

numpyro/distributions/continuous.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -698,9 +698,6 @@ def variance(self):
698698
a = (self.rate / (self.concentration - 1)) ** 2 / (self.concentration - 2)
699699
return jnp.where(self.concentration <= 2, jnp.inf, a)
700700

701-
def cdf(self, x):
702-
return 1 - self.base_dist.cdf(1 / x)
703-
704701
def entropy(self):
705702
return (
706703
self.concentration
@@ -1205,9 +1202,6 @@ def mean(self):
12051202
def variance(self):
12061203
return (jnp.exp(self.scale**2) - 1) * jnp.exp(2 * self.loc + self.scale**2)
12071204

1208-
def cdf(self, x):
1209-
return self.base_dist.cdf(jnp.log(x))
1210-
12111205
def entropy(self):
12121206
return (1 + jnp.log(2 * jnp.pi)) / 2 + self.loc + jnp.log(self.scale)
12131207

@@ -1283,9 +1277,6 @@ def variance(self):
12831277
- self.mean**2
12841278
)
12851279

1286-
def cdf(self, x):
1287-
return self.base_dist.cdf(jnp.log(x))
1288-
12891280
def entropy(self):
12901281
log_low = jnp.log(self.low)
12911282
log_high = jnp.log(self.high)
@@ -2162,12 +2153,6 @@ def variance(self):
21622153
def support(self):
21632154
return constraints.greater_than(self.scale)
21642155

2165-
def cdf(self, value):
2166-
return 1 - jnp.power(self.scale / value, self.alpha)
2167-
2168-
def icdf(self, q):
2169-
return self.scale / jnp.power(1 - q, 1 / self.alpha)
2170-
21712156
def entropy(self):
21722157
return jnp.log(self.scale / self.alpha) + 1 + 1 / self.alpha
21732158

numpyro/distributions/distribution.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1105,6 +1105,24 @@ def mean(self):
11051105
def variance(self):
11061106
raise NotImplementedError
11071107

1108+
def cdf(self, value):
1109+
sign = 1
1110+
for transform in reversed(self.transforms):
1111+
sign *= transform.sign
1112+
value = transform.inv(value)
1113+
q = self.base_dist.cdf(value)
1114+
return jnp.where(sign < 0, 1 - q, q)
1115+
1116+
def icdf(self, q):
1117+
sign = 1
1118+
for transform in self.transforms:
1119+
sign *= transform.sign
1120+
q = jnp.where(sign < 0, 1 - q, q)
1121+
value = self.base_dist.icdf(q)
1122+
for transform in self.transforms:
1123+
value = transform(value)
1124+
return value
1125+
11081126

11091127
class FoldedDistribution(TransformedDistribution):
11101128
"""

numpyro/distributions/transforms.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,15 @@ def inverse_shape(self, shape):
104104
"""
105105
return shape
106106

107+
@property
108+
def sign(self):
109+
"""
110+
Sign of the derivative of the transform if it is bijective.
111+
"""
112+
raise NotImplementedError(
113+
f"Transform `{self.__class__.__name__}` does not implement `sign`."
114+
)
115+
107116
# Allow for pickle serialization of transforms.
108117
def __getstate__(self):
109118
attrs = {}
@@ -147,6 +156,10 @@ def domain(self):
147156
def codomain(self):
148157
return self._inv.domain
149158

159+
@property
160+
def sign(self):
161+
return self._inv.sign
162+
150163
@property
151164
def inv(self):
152165
return self._inv
@@ -231,6 +244,10 @@ def codomain(self):
231244
else:
232245
raise NotImplementedError
233246

247+
@property
248+
def sign(self):
249+
return jnp.sign(self.scale)
250+
234251
def __call__(self, x):
235252
return self.loc + self.scale * x
236253

@@ -309,6 +326,13 @@ def codomain(self):
309326
self.parts[-1].codomain, output_event_dim - last_output_event_dim
310327
)
311328

329+
@property
330+
def sign(self):
331+
sign = 1
332+
for transform in self.parts:
333+
sign *= transform.sign
334+
return sign
335+
312336
def __call__(self, x):
313337
for part in self.parts:
314338
x = part(x)
@@ -509,6 +533,8 @@ def log_abs_det_jacobian(self, x, y, intermediates=None):
509533

510534

511535
class ExpTransform(Transform):
536+
sign = 1
537+
512538
# TODO: refine domain/codomain logic through setters, especially when
513539
# transforms for inverses are supported
514540
def __init__(self, domain=constraints.real):
@@ -550,6 +576,8 @@ def __eq__(self, other):
550576

551577

552578
class IdentityTransform(ParameterFreeTransform):
579+
sign = 1
580+
553581
def __call__(self, x):
554582
return x
555583

@@ -912,9 +940,14 @@ def __eq__(self, other):
912940
return False
913941
return jnp.array_equal(self.exponent, other.exponent)
914942

943+
@property
944+
def sign(self):
945+
return jnp.sign(self.exponent)
946+
915947

916948
class SigmoidTransform(ParameterFreeTransform):
917949
codomain = constraints.unit_interval
950+
sign = 1
918951

919952
def __call__(self, x):
920953
return _clipped_expit(x)
@@ -1006,6 +1039,7 @@ class SoftplusTransform(ParameterFreeTransform):
10061039

10071040
domain = constraints.real
10081041
codomain = constraints.softplus_positive
1042+
sign = 1
10091043

10101044
def __call__(self, x):
10111045
return softplus(x)
@@ -1177,6 +1211,7 @@ class ReshapeTransform(Transform):
11771211

11781212
domain = constraints.real
11791213
codomain = constraints.real
1214+
sign = 1
11801215

11811216
def __init__(self, forward_shape, inverse_shape) -> None:
11821217
forward_size = math.prod(forward_shape)

test/test_distributions.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -540,6 +540,7 @@ def get_sp_dist(jax_dist):
540540
T(dist.HalfNormal, 1.0),
541541
T(dist.HalfNormal, np.array([1.0, 2.0])),
542542
T(_ImproperWrapper, constraints.positive, (), (3,)),
543+
T(dist.InverseGamma, np.array([3.1]), np.array([[2.0], [3.0]])),
543544
T(dist.InverseGamma, np.array([1.7]), np.array([[2.0], [3.0]])),
544545
T(dist.InverseGamma, np.array([0.5, 1.3]), np.array([[1.0], [3.0]])),
545546
T(dist.Kumaraswamy, 10.0, np.array([2.0, 3.0])),
@@ -1568,7 +1569,7 @@ def test_cdf_and_icdf(jax_dist, sp_dist, params):
15681569
samples = d.sample(key=random.PRNGKey(0), sample_shape=(100,))
15691570
quantiles = random.uniform(random.PRNGKey(1), (100,) + d.shape())
15701571
try:
1571-
rtol = 2e-3 if jax_dist in (dist.Gamma, dist.StudentT) else 1e-5
1572+
rtol = 2e-3 if jax_dist in (dist.Gamma, dist.LogNormal, dist.StudentT) else 1e-5
15721573
if d.shape() == () and not d.is_discrete:
15731574
assert_allclose(
15741575
jax.vmap(jax.grad(d.cdf))(samples),
@@ -1585,7 +1586,7 @@ def test_cdf_and_icdf(jax_dist, sp_dist, params):
15851586
assert_allclose(d.cdf(d.icdf(quantiles)), quantiles, atol=1e-5, rtol=1e-5)
15861587
assert_allclose(d.icdf(d.cdf(samples)), samples, atol=1e-5, rtol=rtol)
15871588
except NotImplementedError:
1588-
pass
1589+
pytest.skip("cdf/icdf not implemented")
15891590

15901591
# test against scipy
15911592
if not sp_dist:
@@ -1599,7 +1600,7 @@ def test_cdf_and_icdf(jax_dist, sp_dist, params):
15991600
expected_icdf = sp_dist.ppf(quantiles)
16001601
assert_allclose(actual_icdf, expected_icdf, atol=1e-4, rtol=1e-4)
16011602
except NotImplementedError:
1602-
pass
1603+
pytest.skip("cdf/icdf not implemented")
16031604

16041605

16051606
@pytest.mark.parametrize("jax_dist, sp_dist, params", CONTINUOUS + DIRECTIONAL)

test/test_transforms.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,9 @@ def test_bijective_transforms(transform, shape):
341341
)
342342
slogdet = jnp.linalg.slogdet(jac)
343343
assert jnp.allclose(log_abs_det_jacobian, slogdet.logabsdet, atol=atol)
344+
assert transform.domain.event_dim or jnp.allclose(
345+
jnp.sign(jnp.diagonal(jac, axis1=-1, axis2=-2)), transform.sign
346+
)
344347

345348

346349
def test_batched_recursive_linear_transform():

0 commit comments

Comments
 (0)