Skip to content

Commit 8849837

Browse files
LarsKue and stefanradev93
authored and committed
Add Rational Quadratic Spline Transforms to Normalizing Flows (#291)
* Splines draft * update keras requirement * small improvements to error messages * add rq spline function * add spline transform * update searchsorted utils for jax also add padd util * update tests * add assert_allclose util for improved messages * parametrize transform for flow tests * update jacobian, jacobian trace, vjp, jvp, and corresponding usages and tests * fix imports, remove old jacobian and jvp, fix application in free form flow * improve logdet computation in free form flows * Fix comparison for symbolic tensors under tf * Add splines to twomoons notebook * improve pad utility * fix missing left edge in spline * fix inside mask edge case * explicitly set bias initializer * add better expand utility * small clean up, renaming * fix indexing, fix inside check * dump * fix sign of log jacobian for inverse pass in rq spline * fix parameter splitting for spline transform * improve readability * fix scale and shift trailing dimension * fix inverse pass return value * correctly choose bins once for each dimension, even for multi-dimensional inputs * run formatter * reduce searchsorted log spam * log backend used at setup * remove maximum message cache size * Improve warning message for jax searchsorted * Fix spline parameter binning for compiled contexts * update inverse transform same as forward * Update TwoMoons notebook with splines WIP [skip ci] * fix spline inverse call for out of bounds values * Add working splines --------- Co-authored-by: stefanradev93 <[email protected]>
1 parent ef3892e commit 8849837

File tree

33 files changed

+1033
-762
lines changed

33 files changed

+1033
-762
lines changed

bayesflow/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@ def setup():
3636

3737
torch.autograd.set_grad_enabled(False)
3838

39+
from bayesflow.utils import logging
40+
41+
logging.info(f"Using backend {keras.backend.backend()!r}")
42+
3943

4044
# call and clean up namespace
4145
setup()

bayesflow/diagnostics/plots/calibration_ecdf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ def calibration_ecdf(
176176
titles = ["Stacked ECDFs"]
177177

178178
for ax, title in zip(plot_data["axes"].flat, titles):
179-
ax.fill_between(z, L, H, color=fill_color, alpha=0.2, label=rf"{int((1-alpha) * 100)}$\%$ Confidence Bands")
179+
ax.fill_between(z, L, H, color=fill_color, alpha=0.2, label=rf"{int((1 - alpha) * 100)}$\%$ Confidence Bands")
180180
ax.legend(fontsize=legend_fontsize)
181181
ax.set_title(title, fontsize=title_fontsize)
182182

bayesflow/diagnostics/plots/mmd_hypothesis_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def fill_area_under_kde(kde_object, x_start, x_end=None, **kwargs):
7979

8080
mmd_critical = ops.quantile(mmd_null, 1 - alpha_level)
8181
fill_area_under_kde(
82-
kde, mmd_critical, color=alpha_color, alpha=0.5, label=rf"{int(alpha_level*100)}% rejection area"
82+
kde, mmd_critical, color=alpha_color, alpha=0.5, label=rf"{int(alpha_level * 100)}% rejection area"
8383
)
8484

8585
if truncate_v_lines_at_kde:

bayesflow/networks/consistency_models/continuous_consistency_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ def f_teacher(x, t):
249249
ops.cos(t) * ops.sin(t) * self.sigma_data,
250250
)
251251

252-
teacher_output, cos_sin_dFdt = jvp(f_teacher, primals, tangents)
252+
teacher_output, cos_sin_dFdt = jvp(f_teacher, primals, tangents, return_output=True)
253253
teacher_output = ops.stop_gradient(teacher_output)
254254
cos_sin_dFdt = ops.stop_gradient(cos_sin_dFdt)
255255

bayesflow/networks/coupling_flow/couplings/single_coupling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import keras
2-
32
from keras.saving import register_keras_serializable as serializable
43

54
from bayesflow.types import Tensor
@@ -24,6 +23,7 @@ def __init__(self, subnet: str | type = "mlp", transform: str = "affine", **kwar
2423

2524
output_projector_kwargs = kwargs.get("output_projector_kwargs", {})
2625
output_projector_kwargs.setdefault("kernel_initializer", "zeros")
26+
output_projector_kwargs.setdefault("bias_initializer", "zeros")
2727
self.output_projector = keras.layers.Dense(units=None, **output_projector_kwargs)
2828

2929
# serialization: store all parameters necessary to call __init__
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
from typing import TypedDict

import keras

from bayesflow.types import Tensor


class Edges(TypedDict):
    """Knot coordinates bounding one spline bin.

    The bin maps the x-interval [left, right] onto the y-interval [bottom, top].
    All values are tensors, broadcast-compatible with the spline input.
    """

    left: Tensor  # $x^{(k)}$
    right: Tensor  # $x^{(k+1)}$
    bottom: Tensor  # $y^{(k)}$
    top: Tensor  # $y^{(k+1)}$


class Derivatives(TypedDict):
    """Spline derivatives at the two knots bounding one bin."""

    left: Tensor  # $\delta^{(k)}$
    right: Tensor  # $\delta^{(k+1)}$


def _rational_quadratic_spline(
    x: Tensor, edges: Edges, derivatives: Derivatives, inverse: bool = False
) -> tuple[Tensor, Tensor]:
    """Evaluate a monotonic rational-quadratic spline bin (or its inverse).

    Equation numbers in the comments refer to "the paper" — presumably
    Durkan et al. (2019), "Neural Spline Flows", whose Eqs. 4, 5, 6-8 and 29
    match the formulas used here (TODO: confirm the intended citation).

    Parameters
    ----------
    x : Tensor
        Input values, assumed to already lie inside the given bin
        (the caller is responsible for bin selection — TODO confirm).
    edges : Edges
        Bin-edge coordinates (see :class:`Edges`).
    derivatives : Derivatives
        Knot derivatives (see :class:`Derivatives`).
    inverse : bool
        If True, `x` is interpreted as a y-value and the inverse spline
        is evaluated.

    Returns
    -------
    (result, log_jac) : tuple[Tensor, Tensor]
        The transformed values and the elementwise log absolute Jacobian
        determinant (negated for the inverse pass).
    """
    # rename variables to match the paper:

    # $x^{(k)}$
    xk = edges["left"]

    # $x^{(k+1)}$
    xkp = edges["right"]

    # $y^{(k)}$
    yk = edges["bottom"]

    # $y^{(k+1)}$
    ykp = edges["top"]

    # $delta^{(k)}$
    dk = derivatives["left"]

    # $delta^{(k+1)}$
    dkp = derivatives["right"]

    # commonly used values
    dx = xkp - xk
    dy = ykp - yk
    sk = dy / dx  # bin slope $s^{(k)}$

    if not inverse:
        # relative position within the bin, $\xi \in [0, 1]$
        xi = (x - xk) / dx

        # Eq. 4 in the paper
        numerator = dy * (sk * xi**2 + dk * xi * (1 - xi))
        denominator = sk + (dkp + dk - 2 * sk) * xi * (1 - xi)
        result = yk + numerator / denominator
    else:
        # rename for clarity
        y = x

        # Eq. 6-8 in the paper: coefficients of the quadratic in $\xi$
        a = dy * (sk - dk) + (y - yk) * (dkp + dk - 2 * sk)
        b = dy * dk - (y - yk) * (dkp + dk - 2 * sk)
        c = -sk * (y - yk)

        # Eq. 29 in the appendix of the paper
        discriminant = b**2 - 4 * a * c

        # the discriminant must be positive, even when the spline is called out of bounds
        discriminant = keras.ops.maximum(discriminant, 0)

        # numerically stable root of the quadratic (avoids catastrophic
        # cancellation compared to the textbook formula)
        xi = 2 * c / (-b - keras.ops.sqrt(discriminant))
        result = xi * dx + xk

    # Eq 5 in the paper: derivative of the forward spline at $\xi$
    numerator = sk**2 * (dkp * xi**2 + 2 * sk * xi * (1 - xi) + dk * (1 - xi) ** 2)
    denominator = (sk + (dkp + dk - 2 * sk) * xi * (1 - xi)) ** 2
    log_jac = keras.ops.log(numerator) - keras.ops.log(denominator)

    # the inverse Jacobian is the reciprocal, so its log flips sign
    if inverse:
        log_jac = -log_jac

    return result, log_jac

0 commit comments

Comments
 (0)