Skip to content

Commit 79d727a

Browse files
committed
Improve coupling flow interface and tests
1 parent f8a68b8 commit 79d727a

File tree

10 files changed

+89
-33
lines changed

10 files changed

+89
-33
lines changed

bayesflow/networks/coupling_flow/actnorm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def build(self, xz_shape: Shape, **kwargs):
3434
self.scale = self.add_weight(shape=(xz_shape[-1],), initializer="ones", name="scale")
3535
self.bias = self.add_weight(shape=(xz_shape[-1],), initializer="zeros", name="bias")
3636

37-
def call(self, xz: Tensor, inverse: bool = False, **kwargs):
37+
def call(self, xz: Tensor, inverse: bool = False, **kwargs) -> (Tensor, Tensor):
3838
if inverse:
3939
return self._inverse(xz, **kwargs)
4040
return self._forward(xz, **kwargs)

bayesflow/networks/coupling_flow/coupling_flow.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ def __init__(
4646
permutation: str | None = "random",
4747
use_actnorm: bool = True,
4848
base_distribution: str = "normal",
49+
subnet_kwargs: dict[str, any] = None,
50+
transform_kwargs: dict[str, any] = None,
4951
**kwargs,
5052
):
5153
"""
@@ -82,9 +84,15 @@ def __init__(
8284
base_distribution : str, optional
8385
The base probability distribution from which samples are drawn, such as
8486
"normal". Default is "normal".
87+
subnet_kwargs : dict of str to any, optional
88+
Keyword arguments forwarded to the subnet (e.g., MLP) constructor within
89+
each coupling layer, such as hidden sizes or activation choices.
90+
transform_kwargs : dict of str to any, optional
91+
Keyword arguments forwarded to the affine or spline transforms
92+
(e.g., bins for splines).
8593
**kwargs
86-
Additional keyword arguments passed to the ActNorm, permutation, and
87-
coupling layers for customization.
94+
Additional keyword arguments passed to `InvertibleLayer`.
95+
8896
"""
8997
super().__init__(base_distribution=base_distribution, **kwargs)
9098

@@ -97,12 +105,18 @@ def __init__(
97105
self.invertible_layers = []
98106
for i in range(depth):
99107
if use_actnorm:
100-
self.invertible_layers.append(ActNorm(**kwargs.get("actnorm_kwargs", {})))
108+
self.invertible_layers.append(ActNorm())
101109

102-
if (p := find_permutation(permutation, **kwargs.get("permutation_kwargs", {}))) is not None:
110+
if (p := find_permutation(permutation)) is not None:
103111
self.invertible_layers.append(p)
104112

105-
self.invertible_layers.append(DualCoupling(subnet, transform, **kwargs.get("coupling_kwargs", {})))
113+
self.invertible_layers.append(
114+
DualCoupling(subnet, transform, subnet_kwargs=subnet_kwargs, transform_kwargs=transform_kwargs)
115+
)
116+
117+
# Only the coupling flow itself needs to store these kwargs (for get_config()), since we do not serialize invertible layers
118+
self.subnet_kwargs = subnet_kwargs
119+
self.transform_kwargs = transform_kwargs
106120

107121
# noinspection PyMethodOverriding
108122
def build(self, xz_shape, conditions_shape=None):
@@ -126,6 +140,8 @@ def get_config(self):
126140
"permutation": self.permutation,
127141
"use_actnorm": self.use_actnorm,
128142
"base_distribution": self.base_distribution,
143+
"subnet_kwargs": self.subnet_kwargs,
144+
"transform_kwargs": self.transform_kwargs,
129145
}
130146

131147
return base_config | serialize(config)

bayesflow/networks/coupling_flow/couplings/single_coupling.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,16 +29,20 @@ def __init__(
2929
self,
3030
subnet: str | type = "mlp",
3131
transform: str = "affine",
32+
subnet_kwargs: dict[str, any] = None,
33+
transform_kwargs: dict[str, any] = None,
3234
**kwargs,
3335
):
3436
super().__init__(**kwargs)
3537

36-
subnet_kwargs = kwargs.get("subnet_kwargs", {})
38+
subnet_kwargs = subnet_kwargs or {}
39+
transform_kwargs = transform_kwargs or {}
40+
3741
if subnet == "mlp":
3842
subnet_kwargs = SingleCoupling.MLP_DEFAULT_CONFIG | subnet_kwargs
3943

4044
self.subnet = find_network(subnet, **subnet_kwargs)
41-
self.transform = find_transform(transform, **kwargs.get("transform_kwargs", {}))
45+
self.transform = find_transform(transform, **transform_kwargs)
4246

4347
self.output_projector = keras.layers.Dense(
4448
units=None, kernel_initializer="zeros", bias_initializer="zeros", name="output_projector"

bayesflow/networks/coupling_flow/permutations/fixed_permutation.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,22 @@
11
import keras
2-
from keras.saving import register_keras_serializable as serializable
32

43
from bayesflow.types import Shape, Tensor
4+
from bayesflow.utils.serialization import serializable
5+
56
from ..invertible_layer import InvertibleLayer
67

78

8-
@serializable(package="networks.coupling_flow")
9+
@serializable
910
class FixedPermutation(InvertibleLayer):
10-
def __init__(self, forward_indices=None, inverse_indices=None, **kwargs):
11+
"""
12+
Interface class for permutations with no learnable parameters. Child classes should
13+
create forward and inverse indices in the associated build() method.
14+
"""
15+
16+
def __init__(self, **kwargs):
1117
super().__init__(**kwargs)
12-
self.forward_indices = forward_indices
13-
self.inverse_indices = inverse_indices
18+
self.forward_indices = None
19+
self.inverse_indices = None
1420

1521
def call(self, xz: Tensor, inverse: bool = False, **kwargs):
1622
if inverse:

bayesflow/networks/coupling_flow/permutations/orthogonal.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
from keras import ops
2-
from keras.saving import register_keras_serializable as serializable
32

43
from bayesflow.types import Shape, Tensor
4+
from bayesflow.utils.serialization import serializable
5+
56
from ..invertible_layer import InvertibleLayer
67

78

8-
@serializable(package="networks.coupling_flow")
9+
@serializable
910
class OrthogonalPermutation(InvertibleLayer):
1011
"""Implements a learnable orthogonal transformation according to [1]. Can be
1112
used as an alternative to a fixed ``Permutation`` layer.
@@ -21,7 +22,7 @@ def __init__(self, **kwargs):
2122
def build(self, xz_shape: Shape, **kwargs) -> None:
2223
self.weight = self.add_weight(shape=(xz_shape[-1], xz_shape[-1]), initializer="orthogonal", trainable=True)
2324

24-
def call(self, xz: Tensor, inverse: bool = False, **kwargs):
25+
def call(self, xz: Tensor, inverse: bool = False, **kwargs) -> (Tensor, Tensor):
2526
if inverse:
2627
return self._inverse(xz)
2728
return self._forward(xz)

bayesflow/networks/coupling_flow/permutations/random.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import keras
2-
from keras.saving import register_keras_serializable as serializable
32

43
from bayesflow.types import Shape
4+
from bayesflow.utils.serialization import serializable
5+
56
from .fixed_permutation import FixedPermutation
67

78

8-
@serializable(package="networks.coupling_flow")
9+
@serializable
910
class RandomPermutation(FixedPermutation):
1011
# noinspection PyMethodOverriding
1112
def build(self, xz_shape: Shape, **kwargs) -> None:

bayesflow/networks/coupling_flow/permutations/swap.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import keras
2-
from keras.saving import register_keras_serializable as serializable
32

43
from bayesflow.types import Shape
4+
from bayesflow.utils.serialization import serializable
5+
56
from .fixed_permutation import FixedPermutation
67

78

8-
@serializable(package="networks.coupling_flow")
9+
@serializable
910
class Swap(FixedPermutation):
1011
def build(self, xz_shape: Shape, **kwargs) -> None:
1112
shift = xz_shape[-1] // 2

bayesflow/networks/coupling_flow/transforms/affine_transform.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
import keras.ops as ops
2-
from keras.saving import register_keras_serializable as serializable
32

43
from bayesflow.types import Tensor
54
from bayesflow.utils.keras_utils import shifted_softplus
5+
from bayesflow.utils.serialization import serializable
6+
67
from .transform import Transform
78

89

9-
@serializable(package="networks.coupling_flow")
10+
@serializable
1011
class AffineTransform(Transform):
1112
def __init__(self, clamp: bool = True, **kwargs):
1213
super().__init__(**kwargs)

bayesflow/networks/coupling_flow/transforms/spline_transform.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,16 @@
11
import keras
22
import numpy as np
3-
from keras.saving import (
4-
register_keras_serializable as serializable,
5-
)
63

74
from bayesflow.types import Tensor
85
from bayesflow.utils import pad, searchsorted
96
from bayesflow.utils.keras_utils import shifted_softplus
7+
from bayesflow.utils.serialization import serializable
8+
109
from ._rational_quadratic import _rational_quadratic_spline
1110
from .transform import Transform
1211

1312

14-
@serializable(package="networks.coupling_flow")
13+
@serializable
1514
class SplineTransform(Transform):
1615
def __init__(
1716
self,

tests/test_networks/conftest.py

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,23 +8,41 @@ def flow_matching():
88
from bayesflow.networks import FlowMatching
99

1010
return FlowMatching(
11-
subnet=MLP([64, 64]),
11+
subnet=MLP([8, 8]),
1212
integrate_kwargs={"method": "rk45", "steps": 100},
1313
)
1414

1515

1616
@pytest.fixture()
17-
def coupling_flow():
17+
def consistency_model():
18+
from bayesflow.networks import ConsistencyModel
19+
20+
return ConsistencyModel(total_steps=100, subnet=MLP([8, 8]))
21+
22+
23+
@pytest.fixture()
24+
def affine_coupling_flow():
25+
from bayesflow.networks import CouplingFlow
26+
27+
return CouplingFlow(
28+
depth=2, subnet="mlp", subnet_kwargs=dict(widths=[8, 8]), transform="affine", transform_kwargs=dict(clamp=1.8)
29+
)
30+
31+
32+
@pytest.fixture()
33+
def spline_coupling_flow():
1834
from bayesflow.networks import CouplingFlow
1935

20-
return CouplingFlow(depth=2, subnet="mlp", subnet_kwargs=dict(widths=[64, 64]))
36+
return CouplingFlow(
37+
depth=2, subnet="mlp", subnet_kwargs=dict(widths=[8, 8]), transform="spline", transform_kwargs=dict(bins=8)
38+
)
2139

2240

2341
@pytest.fixture()
2442
def free_form_flow():
2543
from bayesflow.experimental import FreeFormFlow
2644

27-
return FreeFormFlow(encoder_subnet=MLP([64, 64]), decoder_subnet=MLP([64, 64]))
45+
return FreeFormFlow(encoder_subnet=MLP([16, 16]), decoder_subnet=MLP([16, 16]))
2846

2947

3048
@pytest.fixture()
@@ -47,7 +65,7 @@ def typical_point_inference_network_subnet():
4765
from bayesflow.networks import PointInferenceNetwork
4866
from bayesflow.scores import MeanScore, MedianScore, QuantileScore, MultivariateNormalScore
4967

50-
subnet = MLP([64, 64])
68+
subnet = MLP([16, 8])
5169

5270
return PointInferenceNetwork(
5371
scores=dict(
@@ -61,7 +79,14 @@ def typical_point_inference_network_subnet():
6179

6280

6381
@pytest.fixture(
64-
params=["typical_point_inference_network", "coupling_flow", "flow_matching", "free_form_flow"], scope="function"
82+
params=[
83+
"typical_point_inference_network",
84+
"affine_coupling_flow",
85+
"spline_coupling_flow",
86+
"flow_matching",
87+
"free_form_flow",
88+
],
89+
scope="function",
6590
)
6691
def inference_network(request):
6792
return request.getfixturevalue(request.param)
@@ -80,7 +105,9 @@ def inference_network_subnet(request):
80105
return request.getfixturevalue(request.param)
81106

82107

83-
@pytest.fixture(params=["coupling_flow", "flow_matching", "free_form_flow"], scope="function")
108+
@pytest.fixture(
109+
params=["affine_coupling_flow", "spline_coupling_flow", "flow_matching", "free_form_flow"], scope="function"
110+
)
84111
def generative_inference_network(request):
85112
return request.getfixturevalue(request.param)
86113

0 commit comments

Comments
 (0)