Skip to content

Commit 9686a70

Browse files
authored
Merge pull request #219 from danielward27/better_errors
Better errors
2 parents a8ddd90 + 570fe26 commit 9686a70

File tree

5 files changed

+70
-17
lines changed

5 files changed

+70
-17
lines changed

flowjax/bijections/utils.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -293,8 +293,23 @@ def __init__(
293293
bijection: AbstractBijection,
294294
inverter: Callable[[AbstractBijection, Array, Array | None], Array],
295295
):
296+
@eqx.filter_custom_jvp
297+
def nondiff_inverter(bijection, y, condition):
298+
return inverter(bijection, y, condition)
299+
300+
@nondiff_inverter.def_jvp
301+
def nondiff_inverter_jvp(*args, **kwargs):
302+
raise RuntimeError(
303+
"Computing gradients through the numerical inverse would lead to "
304+
"misleading results. If you are using a flow with the analytical "
305+
"transform only defined in one direction, consider inverting the "
306+
"bijection by flipping the ``invert`` argument to the flow. If this is "
307+
"not possible, consider using implicit differentation (not yet "
308+
"supported)."
309+
)
310+
296311
self.bijection = bijection
297-
self.inverter = inverter
312+
self.inverter = nondiff_inverter
298313
self.shape = self.bijection.shape
299314
self.cond_shape = self.bijection.cond_shape
300315

flowjax/distributions.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,27 @@ class AbstractTransformed(AbstractDistribution):
244244
base_dist: AbstractVar[AbstractDistribution]
245245
bijection: AbstractVar[AbstractBijection]
246246

247+
def __check_init__(self):
248+
"""Check for compatible shapes between base_dist and bijection."""
249+
if (
250+
self.base_dist.cond_shape is not None
251+
and self.bijection.cond_shape is not None
252+
and self.base_dist.cond_shape != self.bijection.cond_shape
253+
):
254+
raise ValueError(
255+
"The base distribution and bijection are both conditional "
256+
"but have mismatched cond_shape attributes. Base distribution has"
257+
f"{self.base_dist.cond_shape}, and the bijection has"
258+
f"{self.bijection.cond_shape}.",
259+
)
260+
261+
if self.base_dist.shape != self.bijection.shape:
262+
raise ValueError(
263+
"The base distribution and bijection have mismatched shapes. "
264+
f"Base distribution has {self.base_dist.shape}, and the bijection "
265+
f"has {self.bijection.shape}.",
266+
)
267+
247268
def _log_prob(self, x, condition=None):
248269
z, log_abs_det = self.bijection.inverse_and_log_det(x, condition)
249270
p_z = self.base_dist._log_prob(z, condition)
@@ -268,20 +289,6 @@ def _sample_and_log_prob(
268289
)
269290
return sample, log_prob_base - forward_log_dets
270291

271-
def __check_init__(self): # TODO test errors and test conditional base distribution
272-
"""Checks cond_shape is compatible in both bijection and distribution."""
273-
if (
274-
self.base_dist.cond_shape is not None
275-
and self.bijection.cond_shape is not None
276-
and self.base_dist.cond_shape != self.bijection.cond_shape
277-
):
278-
raise ValueError(
279-
"The base distribution and bijection are both conditional "
280-
"but have mismatched cond_shape attributes. Base distribution has"
281-
f"{self.base_dist.cond_shape}, and the bijection has"
282-
f"{self.bijection.cond_shape}.",
283-
)
284-
285292
def merge_transforms(self):
286293
"""Unnests nested transformed distributions.
287294

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ license = { file = "LICENSE" }
2323
name = "flowjax"
2424
readme = "README.md"
2525
requires-python = ">=3.10"
26-
version = "17.1.1"
26+
version = "17.1.2"
2727

2828
[project.urls]
2929
repository = "https://github.com/danielward27/flowjax"

tests/test_bijections/test_bijection_utils.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
1+
from functools import partial
2+
3+
import jax
14
import jax.numpy as jnp
25
import pytest
36
from equinox import EquinoxRuntimeError
47

5-
from flowjax.bijections import Affine, Indexed, Permute
8+
from flowjax.bijections import Affine, Indexed, NumericalInverse, Permute
9+
from flowjax.root_finding import bisection_search, root_finder_to_inverter
610

711
test_cases = {
812
# name: idx, expected
@@ -32,3 +36,20 @@ def test_partial(idx, expected):
3236
def test_Permute_argcheck():
3337
with pytest.raises(EquinoxRuntimeError):
3438
Permute(jnp.array([0, 0]))
39+
40+
41+
test_cases = [jax.grad, jax.jacfwd, jax.jacrev]
42+
43+
44+
@pytest.mark.parametrize("diff_fn", test_cases)
45+
def test_NumericalInverse_not_differentiable(diff_fn):
46+
bijection = NumericalInverse(
47+
Affine(5, 2),
48+
root_finder_to_inverter(
49+
partial(bisection_search, lower=-1, upper=1, atol=1e-7),
50+
),
51+
)
52+
with pytest.raises(
53+
RuntimeError, match="Computing gradients through the numerical inverse"
54+
):
55+
diff_fn(bijection.inverse)(jnp.ones(()))

tests/test_distributions.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ def test_uniform_params():
120120

121121
class _TestDist(AbstractDistribution):
122122
"Toy distribution object, for testing of distribution broadcasting."
123+
123124
shape: tuple[int, ...]
124125
cond_shape: tuple[int, ...] | None
125126

@@ -292,3 +293,12 @@ def test_transformed():
292293
assert dist.sample(jr.key(0), condition=jnp.ones((5, 2))).shape == (5,)
293294
assert dist.shape == ()
294295
assert dist.cond_shape == (2,)
296+
297+
298+
def test_transformed_wrong_shape():
299+
300+
with pytest.raises(ValueError, match="mismatched shapes"):
301+
Transformed(
302+
StandardNormal((2,)),
303+
Affine(),
304+
)

0 commit comments

Comments
 (0)