
Commit 0537f2a

vpratz and LarsKue authored
Add free-form flows as inference networks (#251)
* feat: add free-form flows as inference networks
* implements the fff loss
* still missing: calculation of the log probability
* fff: add log jacobian determinant computation
* util: make vjp globally accessible
  Change `torch.autograd.functional.vjp` to `torch.func.vjp` as the former implementation broke gradient flow. It then also uses the same API as Jax, making the code easier to parse.
* utils: change autograd backend for torch jvp
  Change from `torch.autograd.functional.jvp` to `torch.func.jvp`, as recommended in the documentation: https://pytorch.org/docs/stable/generated/torch.autograd.functional.jvp.html
  Using `autograd.functional` seems to break the gradient flow, while `func` does not produce problems in this regard.
* fff: use vjp and jvp from utils
* improve docs and type hints
* fix vjp call in fff
* add fff to tests, remove flow matching from global tests
* fix default kwargs for fff subnets
* remove double source attribution
* improve type hints
* adjust batch_wrap to handle non-iterable arguments
* fff: handle conditions=None

---------

Co-authored-by: LarsKue <[email protected]>
1 parent 7f37aef commit 0537f2a
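
A note on the `torch.func.vjp` change described above (not part of the diff): the `torch.func` API mirrors `jax.vjp` in that it returns the primal output together with a closure that maps cotangents to input gradients, and, per the commit message, it does not break the surrounding gradient flow the way `torch.autograd.functional.vjp` did. A minimal sketch of the call pattern:

import torch

def f(x):
    return x ** 2

x = torch.randn(8, 3)
y, vjp_fn = torch.func.vjp(f, x)        # y = f(x); vjp_fn maps a cotangent to input gradients
(v_grad,) = vjp_fn(torch.ones_like(y))  # for this f, equals 2 * x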

File tree

12 files changed: +423 −74 lines changed

bayesflow/networks/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -3,6 +3,7 @@
 from .coupling_flow import CouplingFlow
 from .deep_set import DeepSet
 from .flow_matching import FlowMatching
+from .free_form_flow import FreeFormFlow
 from .inference_network import InferenceNetwork
 from .mlp import MLP
 from .lstnet import LSTNet
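
The export above makes the new network reachable from `bayesflow.networks` like any other inference network. A minimal construction sketch (the shapes below are hypothetical, and how the network is wired into a larger approximator is not shown here):

from bayesflow.networks import FreeFormFlow

# defaults taken from the constructor in this commit: beta=50.0, "mlp" subnets,
# normal base distribution, QR-based Hutchinson sampling
flow = FreeFormFlow(beta=50.0, encoder_subnet="mlp", decoder_subnet="mlp")

# hypothetical shapes: 4-dimensional inference targets, 8-dimensional conditions
flow.build(xz_shape=(None, 4), conditions_shape=(None, 8))
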
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
from .free_form_flow import FreeFormFlow
Lines changed: 183 additions & 0 deletions
@@ -0,0 +1,183 @@
import keras
from keras import ops
from keras.saving import register_keras_serializable as serializable

from bayesflow.types import Tensor
from bayesflow.utils import find_network, keras_kwargs, concatenate, log_jacobian_determinant, jvp, vjp

from ..inference_network import InferenceNetwork


@serializable(package="networks.free_form_flow")
class FreeFormFlow(InferenceNetwork):
    """Implements a dimensionality-preserving Free-form Flow.
    Incorporates ideas from [1-2].

    [1] Draxler, F., Sorrenson, P., Zimmermann, L., Rousselot, A., & Köthe, U. (2024).
    Free-form Flows: Make Any Architecture a Normalizing Flow.
    In International Conference on Artificial Intelligence and Statistics.

    [2] Sorrenson, P., Draxler, F., Rousselot, A., Hummerich, S., Zimmermann, L., &
    Köthe, U. (2024). Lifting Architectural Constraints of Injective Flows.
    In International Conference on Learning Representations.
    """

    def __init__(
        self,
        beta: float = 50.0,
        encoder_subnet: str | type = "mlp",
        decoder_subnet: str | type = "mlp",
        base_distribution: str = "normal",
        hutchinson_sampling: str = "qr",
        **kwargs,
    ):
        """Creates an instance of a Free-form Flow.

        Parameters:
        -----------
        beta : float, optional, default: 50.0
        encoder_subnet : str or type, optional, default: "mlp"
            A neural network type for the flow, will be instantiated using
            encoder_subnet_kwargs. Will be equipped with a projector to ensure
            the correct output dimension and a global skip connection.
        decoder_subnet : str or type, optional, default: "mlp"
            A neural network type for the flow, will be instantiated using
            decoder_subnet_kwargs. Will be equipped with a projector to ensure
            the correct output dimension and a global skip connection.
        base_distribution : str, optional, default: "normal"
            The latent distribution
        hutchinson_sampling : str, optional, default: "qr"
            One of `["sphere", "qr"]`. Select the sampling scheme for the
            vectors of the Hutchinson trace estimator.
        **kwargs : dict, optional, default: {}
            Additional keyword arguments
        """
        super().__init__(base_distribution=base_distribution, **keras_kwargs(kwargs))
        self.encoder_subnet = find_network(encoder_subnet, **kwargs.get("encoder_subnet_kwargs", {}))
        self.encoder_projector = keras.layers.Dense(units=None, bias_initializer="zeros", kernel_initializer="zeros")
        self.decoder_subnet = find_network(decoder_subnet, **kwargs.get("decoder_subnet_kwargs", {}))
        self.decoder_projector = keras.layers.Dense(units=None, bias_initializer="zeros", kernel_initializer="zeros")

        self.hutchinson_sampling = hutchinson_sampling
        self.beta = beta

        self.seed_generator = keras.random.SeedGenerator()

    # noinspection PyMethodOverriding
    def build(self, xz_shape, conditions_shape=None):
        super().build(xz_shape)
        self.encoder_projector.units = xz_shape[-1]
        self.decoder_projector.units = xz_shape[-1]

        # construct input shape for subnet and subnet projector
        input_shape = list(xz_shape)

        if conditions_shape is not None:
            input_shape[-1] += conditions_shape[-1]

        input_shape = tuple(input_shape)

        self.encoder_subnet.build(input_shape)
        self.decoder_subnet.build(input_shape)

        input_shape = self.encoder_subnet.compute_output_shape(input_shape)
        self.encoder_projector.build(input_shape)

        input_shape = self.decoder_subnet.compute_output_shape(input_shape)
        self.decoder_projector.build(input_shape)

    def _forward(
        self, x: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs
    ) -> Tensor | tuple[Tensor, Tensor]:
        if density:
            if conditions is None:
                # None cannot be batched, so supply as keyword argument
                z, log_det = log_jacobian_determinant(x, self.encode, conditions=None, training=training, **kwargs)
            else:
                # conditions should be batched, supply as positional argument
                z, log_det = log_jacobian_determinant(x, self.encode, conditions, training=training, **kwargs)

            log_density = self.base_distribution.log_prob(z) + log_det
            return z, log_density

        z = self.encode(x, conditions, training=training, **kwargs)
        return z

    def _inverse(
        self, z: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs
    ) -> Tensor | tuple[Tensor, Tensor]:
        if density:
            if conditions is None:
                # None cannot be batched, so supply as keyword argument
                x, log_det = log_jacobian_determinant(z, self.decode, conditions=None, training=training, **kwargs)
            else:
                # conditions should be batched, supply as positional argument
                x, log_det = log_jacobian_determinant(z, self.decode, conditions, training=training, **kwargs)
            log_density = self.base_distribution.log_prob(z) - log_det
            return x, log_density

        x = self.decode(z, conditions, training=training, **kwargs)
        return x

    def encode(self, x: Tensor, conditions: Tensor = None, training: bool = False, **kwargs) -> Tensor:
        if conditions is None:
            inp = x
        else:
            inp = concatenate(x, conditions, axis=-1)
        network_out = self.encoder_projector(
            self.encoder_subnet(inp, training=training, **kwargs), training=training, **kwargs
        )
        return network_out + x

    def decode(self, z: Tensor, conditions: Tensor = None, training: bool = False, **kwargs) -> Tensor:
        if conditions is None:
            inp = z
        else:
            inp = concatenate(z, conditions, axis=-1)
        network_out = self.decoder_projector(
            self.decoder_subnet(inp, training=training, **kwargs), training=training, **kwargs
        )
        return network_out + z

    def _sample_v(self, x):
        batch_size = ops.shape(x)[0]
        total_dim = ops.shape(x)[-1]
        match self.hutchinson_sampling:
            case "qr":
                # Use QR decomposition as described in [2]
                v_raw = keras.random.normal((batch_size, total_dim, 1), dtype=ops.dtype(x), seed=self.seed_generator)
                q = ops.reshape(ops.qr(v_raw)[0], ops.shape(x))
                v = q * ops.sqrt(total_dim)
            case "sphere":
                # Sample from sphere with radius sqrt(total_dim), as implemented in [1]
                v_raw = keras.random.normal((batch_size, total_dim), dtype=ops.dtype(x), seed=self.seed_generator)
                v = v_raw * ops.sqrt(total_dim) / ops.sqrt(ops.sum(v_raw**2, axis=-1, keepdims=True))
            case _:
                raise ValueError(f"{self.hutchinson_sampling} is not a valid value for hutchinson_sampling.")
        return v

    def compute_metrics(self, x: Tensor, conditions: Tensor = None, stage: str = "training") -> dict[str, Tensor]:
        base_metrics = super().compute_metrics(x, conditions=conditions, stage=stage)
        # sample random vector
        v = self._sample_v(x)

        def encode(x):
            return self.encode(x, conditions, training=stage == "training")

        def decode(z):
            return self.decode(z, conditions, training=stage == "training")

        # VJP computation
        z, vjp_fn = vjp(encode, x)
        v1 = vjp_fn(v)[0]
        # JVP computation
        x_pred, v2 = jvp(decode, (z,), (v,))

        # equivalent: surrogate = ops.matmul(ops.stop_gradient(v2[:, None]), v1[:, :, None])[:, 0, 0]
        surrogate = ops.sum((ops.stop_gradient(v2) * v1), axis=-1)
        nll = -self.base_distribution.log_prob(z)
        maximum_likelihood_loss = nll - surrogate
        reconstruction_loss = ops.sum((x - x_pred) ** 2, axis=-1)
        loss = ops.mean(maximum_likelihood_loss + self.beta * reconstruction_loss)

        return base_metrics | {"loss": loss}

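Following [1], the loss in `compute_metrics` above is a free-form flow surrogate objective: `v1` is the encoder VJP `v^T J_f`, `v2` is the decoder JVP `J_g v` with a stop-gradient, and their dot product is a Hutchinson-style estimate of `tr(sg(J_g) J_f)`, which acts as a surrogate for the gradient of `log|det J_f|` once the decoder approximates the encoder's inverse. A small self-contained check of the underlying trace identity (plain NumPy, illustration only, not part of the diff):

import numpy as np

rng = np.random.default_rng(0)
d = 5
A = rng.normal(size=(d, d))   # stands in for the encoder Jacobian J_f
B = rng.normal(size=(d, d))   # stands in for the decoder Jacobian J_g

v = rng.normal(size=(200_000, d))   # random probes with E[v v^T] = I
v1 = v @ A                          # rows are v^T J_f (the VJP)
v2 = v @ B.T                        # rows are J_g v   (the JVP)
estimate = np.mean(np.sum(v2 * v1, axis=-1))

print(estimate, np.trace(A @ B))    # the two values should be close
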
bayesflow/utils/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -25,7 +25,9 @@
     parse_bytes,
 )
 from .jacobian_trace import jacobian_trace
+from .jacobian import compute_jacobian, log_jacobian_determinant
 from .jvp import jvp
+from .vjp import vjp
 from .optimal_transport import optimal_transport
 from .tensor_utils import (
     expand_left,

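With these exports, the forward- and reverse-mode primitives used by `FreeFormFlow` become available directly from `bayesflow.utils`. A usage sketch; the call signatures follow how `compute_metrics` invokes them, while the toy function and shapes are made up for illustration:

from keras import ops
from bayesflow.utils import jvp, vjp

def f(x):
    return ops.sin(x)

x = ops.ones((4, 3))
v = ops.ones((4, 3))

y, vjp_fn = vjp(f, x)                   # reverse mode: vjp_fn(v) yields v^T (df/dx)
(cotangent,) = vjp_fn(v)
y_again, tangent = jvp(f, (x,), (v,))   # forward mode: (df/dx) v
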
bayesflow/utils/jacobian.py

Lines changed: 129 additions & 0 deletions
@@ -0,0 +1,129 @@
from collections.abc import Callable
import keras
from keras import ops
from bayesflow.types import Tensor

from functools import partial, wraps


def compute_jacobian(
    x_in: Tensor,
    fn: Callable,
    *func_args: any,
    grad_type: str = "backward",
    **func_kwargs: any,
) -> tuple[Tensor, Tensor]:
    """Computes the Jacobian of a function with respect to its input.

    :param x_in: The input tensor to compute the jacobian at.
        Shape: (batch_size, in_dim).
    :param fn: The function to compute the jacobian of, which transforms
        `x` to `fn(x)` of shape (batch_size, out_dim).
    :param func_args: The positional arguments to pass to the function.
        func_args are batched over the first dimension.
    :param grad_type: The type of gradient to use. Either 'backward' or
        'forward'.
    :param func_kwargs: The keyword arguments to pass to the function.
        func_kwargs are not batched.
    :return: The output of the function `fn(x)` and the jacobian
        of the function with respect to its input `x` of shape
        (batch_size, out_dim, in_dim)."""

    def batch_wrap(fn: Callable) -> Callable:
        """Add a batch dimension to each tensor argument.

        :param fn:
        :return: wrapped function"""

        def deep_unsqueeze(arg):
            if ops.is_tensor(arg):
                return arg[None, ...]
            elif isinstance(arg, dict):
                return {key: deep_unsqueeze(value) for key, value in arg.items()}
            elif isinstance(arg, (list, tuple)):
                return [deep_unsqueeze(value) for value in arg]
            raise ValueError(f"Argument cannot be batched: {arg}")

        @wraps(fn)
        def wrapper(*args, **kwargs):
            args = deep_unsqueeze(args)
            return fn(*args, **kwargs)[0]

        return wrapper

    def double_output(fn):
        @wraps(fn)
        def wrapper(*args, **kwargs):
            out = fn(*args, **kwargs)
            return out, out

        return wrapper

    match keras.backend.backend():
        case "torch":
            import torch
            from torch.func import jacrev, jacfwd, vmap

            jacfn = jacrev if grad_type == "backward" else jacfwd
            with torch.inference_mode(False):
                with torch.no_grad():
                    fn_kwargs_prefilled = partial(fn, **func_kwargs)
                    fn_batch_expanded = batch_wrap(fn_kwargs_prefilled)
                    fn_return_val = double_output(fn_batch_expanded)
                    fn_jac_batched = vmap(jacfn(fn_return_val, has_aux=True))
                    jac, x_out = fn_jac_batched(x_in, *func_args)
        case "jax":
            from jax import jacrev, jacfwd, vmap

            jacfn = jacrev if grad_type == "backward" else jacfwd
            fn_kwargs_prefilled = partial(fn, **func_kwargs)
            fn_batch_expanded = batch_wrap(fn_kwargs_prefilled)
            fn_return_val = double_output(fn_batch_expanded)
            fn_jac_batched = vmap(jacfn(fn_return_val, has_aux=True))
            jac, x_out = fn_jac_batched(x_in, *func_args)
        case "tensorflow":
            if grad_type == "forward":
                raise NotImplementedError("For TensorFlow, only backward mode Jacobian computation is available.")
            import tensorflow as tf

            with tf.GradientTape() as tape:
                tape.watch(x_in)
                x_out = fn(x_in, *func_args, **func_kwargs)
            jac = tape.batch_jacobian(x_out, x_in)

        case _:
            raise NotImplementedError(f"compute_jacobian not implemented for {keras.backend.backend()}.")
    return x_out, jac


def log_jacobian_determinant(
    x_in: Tensor,
    fn: Callable,
    *func_args: any,
    grad_type: str = "backward",
    **func_kwargs: any,
) -> tuple[Tensor, Tensor]:
    """Computes the log Jacobian determinant of a function
    with respect to its input.

    :param x_in: The input tensor to compute the jacobian at.
        Shape: (batch_size, in_dim).
    :param fn: The function to compute the jacobian of, which transforms
        `x` to `fn(x)` of shape (batch_size, out_dim).
    :param func_args: The positional arguments to pass to the function.
        func_args are batched over the first dimension.
    :param grad_type: The type of gradient to use. Either 'backward' or
        'forward'.
    :param func_kwargs: The keyword arguments to pass to the function.
        func_kwargs are not batched.
    :return: The output of the function `fn(x)` and the log jacobian determinant
        of the function with respect to its input `x` of shape
        (batch_size,)."""

    x_out, jac = compute_jacobian(x_in, fn, *func_args, grad_type=grad_type, **func_kwargs)
    jac = ops.reshape(
        jac, (ops.shape(x_in)[0], ops.prod(list(ops.shape(x_out)[1:])), ops.prod(list(ops.shape(x_in)[1:])))
    )
    log_det = ops.slogdet(jac)[1]

    return x_out, log_det

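Since `log_jacobian_determinant` returns both the function output and the per-sample log-determinant, a map with a known Jacobian gives a quick sanity check. A sketch using the signature defined above; the scaling function and shapes are made up for illustration:

from keras import ops
from bayesflow.utils import log_jacobian_determinant

def scale(x):
    return 2.0 * x            # Jacobian is 2*I, so log|det| = in_dim * log(2)

x = ops.ones((8, 3))
x_out, log_det = log_jacobian_determinant(x, scale)
# each entry of log_det should be close to 3 * log(2) ≈ 2.079
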
bayesflow/utils/jacobian_trace/_vjp.py

Lines changed: 0 additions & 38 deletions
This file was deleted.
