Skip to content

Commit af19085

Browse files
committed
Enable interleaved flows
1 parent 6c0c153 commit af19085

File tree

1 file changed

+40
-10
lines changed

1 file changed

+40
-10
lines changed

bayesflow/inference_networks.py

Lines changed: 40 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030
class InvertibleNetwork(tf.keras.Model):
3131
"""Implements a chain of conditional invertible coupling layers for conditional density estimation."""
3232

33+
available_designs = ("affine", "spline", "interleaved")
34+
3335
def __init__(
3436
self,
3537
num_params,
@@ -69,16 +71,17 @@ def __init__(
6971
num_params : int
7072
The number of parameters to perform inference on. Equivalently, the dimensionality of the
7173
latent space.
72-
num_coupling_layers : int, optional, default: 5
74+
num_coupling_layers : int, optional, default: 6
7375
The number of coupling layers to use as defined in [1] and [2]. In general, more coupling layers
7476
will give you more expressive power, but will be slower and may need more simulations to train.
7577
Typically, between 4 and 10 coupling layers should suffice for most applications.
7678
coupling_design : str or callable, optional, default: 'affine'
77-
The type of internal coupling network to use. Must be in ['affine', 'spline'].
78-
The former corresponds to the architecture in [3, 5], the latter corresponds to a modified
79-
version of [4].
79+
The type of internal coupling network to use. Must be in ['affine', 'spline', 'interleaved'].
80+
The first corresponds to the architecture in [3, 5], the second corresponds to a modified
81+
version of [4]. The third option will alternate between affine and spline layers, for example,
82+
if num_coupling_layers == 3, the chain will consist of ["affine", "spline", "affine"] layers.
8083
81-
In general, spline couplings run slower than affine couplings, but require fewers coupling
84+
In general, spline couplings run slower than affine couplings, but require fewer coupling
8285
layers. Spline couplings may work best with complex (e.g., multimodal) low-dimensional
8386
problems. The difference will become less and less pronounced as we move to higher dimensions.
8487
@@ -127,16 +130,15 @@ def __init__(
127130

128131
super().__init__(**kwargs)
129132

130-
settings = dict(
133+
layer_settings = dict(
131134
latent_dim=num_params,
132-
coupling_settings=coupling_settings,
133-
coupling_design=coupling_design,
134135
permutation=permutation,
135136
use_act_norm=use_act_norm,
136137
act_norm_init=act_norm_init,
137138
)
138-
self.coupling_layers = []
139-
self.coupling_layers = [CouplingLayer(**settings) for _ in range(num_coupling_layers)]
139+
self.coupling_layers = self._create_coupling_layers(
140+
layer_settings, coupling_settings, coupling_design, num_coupling_layers
141+
)
140142
self.soft_flow = use_soft_flow
141143
self.soft_low = soft_flow_bounds[0]
142144
self.soft_high = soft_flow_bounds[1]
@@ -230,6 +232,34 @@ def inverse(self, z, condition, **kwargs):
230232
target = layer(target, condition, inverse=True, **kwargs)
231233
return target
232234

235+
@staticmethod
236+
def _create_coupling_layers(settings, coupling_settings, coupling_design, num_coupling_layers):
237+
"""Helper method to create a list of coupling layers. Takes care
238+
of the different options for coupling design.
239+
"""
240+
241+
if coupling_design not in InvertibleNetwork.available_designs:
242+
raise NotImplementedError("Coupling design should be one of", InvertibleNetwork.available_designs)
243+
244+
# Case affine or spline
245+
if coupling_design != "interleaved":
246+
design = coupling_design
247+
_coupling_settings = coupling_settings
248+
coupling_layers = [
249+
CouplingLayer(coupling_design=design, coupling_settings=_coupling_settings, **settings)
250+
for _ in range(num_coupling_layers)
251+
]
252+
# Case interleaved, starts with affine
253+
else:
254+
coupling_layers = []
255+
designs = (["affine", "spline"] * int(np.ceil(num_coupling_layers / 2)))[:num_coupling_layers]
256+
for design in designs:
257+
# Fail gently, if neither None, nor a dictionary with keys ("spline", "affine")
258+
_coupling_settings = None if coupling_settings is None else coupling_settings[design]
259+
layer = CouplingLayer(coupling_design=design, coupling_settings=_coupling_settings, **settings)
260+
coupling_layers.append(layer)
261+
return coupling_layers
262+
233263
@classmethod
234264
def create_config(cls, **kwargs):
235265
""" "Used to create the settings dictionary for the internal networks of the invertible

0 commit comments

Comments
 (0)