
Commit fc5b705

Major refactor: model to layer
1 parent 8b6769d commit fc5b705

27 files changed: +74 −151 lines changed

bayesflow/networks/consistency_models/consistency_model.py

Lines changed: 3 additions & 2 deletions
@@ -6,7 +6,7 @@
 import warnings
 
 from bayesflow.types import Tensor
-from bayesflow.utils import find_network, model_kwargs, weighted_mean
+from bayesflow.utils import find_network, layer_kwargs, weighted_mean
 from bayesflow.utils.serialization import deserialize, serializable, serialize
 
 from ..inference_network import InferenceNetwork
@@ -118,7 +118,7 @@ def from_config(cls, config, custom_objects=None):
 
     def get_config(self):
         base_config = super().get_config()
-        base_config = model_kwargs(base_config)
+        base_config = layer_kwargs(base_config)
 
         config = {
             "total_steps": self.total_steps,
@@ -128,6 +128,7 @@ def get_config(self):
             "eps": self.eps,
             "s0": self.s0,
             "s1": self.s1,
+            # we do not need to store subnet_kwargs
        }
 
        return base_config | serialize(config)
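
The get_config pattern above filters the config returned by super().get_config() down to the keyword arguments the base class accepts, then merges in the network's own constructor arguments; after the refactor that filter is layer_kwargs rather than model_kwargs, matching the move from keras.Model to keras.Layer elsewhere in this commit. Below is a minimal sketch of the idea; layer_kwargs_sketch and ExampleNetwork are hypothetical stand-ins, not the actual bayesflow.utils implementation.

# Hypothetical sketch; the real bayesflow.utils.layer_kwargs is not shown in this commit.
import keras


def layer_kwargs_sketch(config: dict) -> dict:
    # Keep only keys that keras.Layer.__init__ understands; drop Model-specific
    # or derived entries so they are not passed twice on reconstruction.
    allowed = {"name", "dtype", "trainable"}
    return {key: value for key, value in config.items() if key in allowed}


class ExampleNetwork(keras.Layer):
    def __init__(self, total_steps: int = 1000, **kwargs):
        super().__init__(**layer_kwargs_sketch(kwargs))
        self.total_steps = total_steps

    def get_config(self):
        base_config = layer_kwargs_sketch(super().get_config())
        config = {"total_steps": self.total_steps}
        return base_config | config


net = ExampleNetwork(total_steps=500, name="example")
print(net.get_config())  # contains the base Layer kwargs plus "total_steps"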

bayesflow/networks/coupling_flow/actnorm.py

Lines changed: 2 additions & 3 deletions
@@ -1,13 +1,12 @@
 from keras import ops
 
-from keras.saving import register_keras_serializable as serializable
-
 from bayesflow.types import Shape, Tensor
+from bayesflow.utils.serialization import serializable
 
 from .invertible_layer import InvertibleLayer
 
 
-@serializable(package="networks.coupling_flow")
+@serializable
 class ActNorm(InvertibleLayer):
     """Implements an Activation Normalization (ActNorm) Layer. Activation Normalization is learned invertible
     normalization, using a scale (s) and a bias (b) vector::
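
Throughout the commit, Keras' register_keras_serializable decorator, previously parameterized with a per-module package string, is replaced by the project's bare @serializable decorator from bayesflow.utils.serialization. Its implementation is not part of this diff; one plausible sketch, assuming it simply derives the registration package from the decorated class, could look like this.

# Hypothetical sketch only; the real bayesflow.utils.serialization.serializable
# may pick the package differently or add extra bookkeeping.
from keras.saving import register_keras_serializable


def serializable(cls):
    # Register the class under its defining module, so call sites no longer
    # repeat a package="..." argument.
    return register_keras_serializable(package=cls.__module__)(cls)


@serializable
class Example:
    def get_config(self):
        return {}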

bayesflow/networks/coupling_flow/coupling_flow.py

Lines changed: 2 additions & 2 deletions
@@ -3,7 +3,7 @@
 from bayesflow.types import Tensor
 from bayesflow.utils import (
     find_permutation,
-    model_kwargs,
+    layer_kwargs,
     weighted_mean,
 )
 from bayesflow.utils.serialization import deserialize, serializable, serialize
@@ -131,7 +131,7 @@ def from_config(cls, config, custom_objects=None):
 
     def get_config(self):
         base_config = super().get_config()
-        base_config = model_kwargs(base_config)
+        base_config = layer_kwargs(base_config)
 
         config = {
             "subnet": self.subnet,

bayesflow/networks/coupling_flow/couplings/dual_coupling.py

Lines changed: 2 additions & 0 deletions
@@ -3,7 +3,9 @@
 from bayesflow.utils import model_kwargs
 from bayesflow.utils.serialization import deserialize, serializable, serialize
 from bayesflow.types import Tensor
+
 from .single_coupling import SingleCoupling
+
 from ..invertible_layer import InvertibleLayer
 
 

bayesflow/networks/deep_set/deep_set.py

Lines changed: 6 additions & 44 deletions
@@ -3,11 +3,11 @@
 import keras
 
 from bayesflow.types import Tensor
-from bayesflow.utils import model_kwargs
-from bayesflow.utils.serialization import deserialize, serializable, serialize
+from bayesflow.utils.serialization import serializable
+
+from .equivariant_layer import EquivariantLayer
+from .invariant_layer import InvariantLayer
 
-from .equivariant_module import EquivariantModule
-from .invariant_module import InvariantModule
 from ..summary_network import SummaryNetwork
 
 
@@ -88,7 +88,7 @@ def __init__(
         # Stack of equivariant modules for a many-to-many learnable transformation
         self.equivariant_modules = []
         for _ in range(depth):
-            equivariant_module = EquivariantModule(
+            equivariant_module = EquivariantLayer(
                 mlp_widths_equivariant=mlp_widths_equivariant,
                 mlp_widths_invariant_inner=mlp_widths_invariant_inner,
                 mlp_widths_invariant_outer=mlp_widths_invariant_outer,
@@ -97,37 +97,24 @@ def __init__(
                 spectral_normalization=spectral_normalization,
                 dropout=dropout,
                 pooling=inner_pooling,
-                name="equivariant_module",
             )
             self.equivariant_modules.append(equivariant_module)
 
         # Invariant module for a many-to-one transformation
-        self.invariant_module = InvariantModule(
+        self.invariant_module = InvariantLayer(
             mlp_widths_inner=mlp_widths_invariant_last,
             mlp_widths_outer=mlp_widths_invariant_last,
             activation=activation,
             kernel_initializer=kernel_initializer,
             dropout=dropout,
             pooling=output_pooling,
             spectral_normalization=spectral_normalization,
-            name="invariant_module",
         )
 
         # Output linear layer to project set representation down to "summary_dim" learned summary statistics
         self.output_projector = keras.layers.Dense(summary_dim, activation="linear", name="output_projector")
 
         self.summary_dim = summary_dim
-        self.depth = depth
-        self.inner_pooling = inner_pooling
-        self.output_pooling = output_pooling
-        self.mlp_widths_equivariant = mlp_widths_equivariant
-        self.mlp_widths_invariant_inner = mlp_widths_invariant_inner
-        self.mlp_widths_invariant_outer = mlp_widths_invariant_outer
-        self.mlp_widths_invariant_last = mlp_widths_invariant_last
-        self.activation = activation
-        self.kernel_initializer = kernel_initializer
-        self.dropout = dropout
-        self.spectral_normalization = spectral_normalization
 
     def call(self, x: Tensor, training: bool = False, **kwargs) -> Tensor:
         """
@@ -161,28 +148,3 @@ def call(self, x: Tensor, training: bool = False, **kwargs) -> Tensor:
         x = self.invariant_module(x, training=training)
 
         return self.output_projector(x)
-
-    @classmethod
-    def from_config(cls, config, custom_objects=None):
-        return cls(**deserialize(config, custom_objects=custom_objects))
-
-    def get_config(self):
-        base_config = super().get_config()
-        base_config = model_kwargs(base_config)
-
-        config = {
-            "summary_dim": self.summary_dim,
-            "depth": self.depth,
-            "inner_pooling": self.inner_pooling,
-            "output_pooling": self.output_pooling,
-            "mlp_widths_equivariant": self.mlp_widths_equivariant,
-            "mlp_widths_invariant_inner": self.mlp_widths_invariant_inner,
-            "mlp_widths_invariant_outer": self.mlp_widths_invariant_outer,
-            "mlp_widths_invariant_last": self.mlp_widths_invariant_last,
-            "activation": self.activation,
-            "kernel_initializer": self.kernel_initializer,
-            "dropout": self.dropout,
-            "spectral_normalization": self.spectral_normalization,
-        }
-
-        return base_config | serialize(config)
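
With DeepSet's explicit from_config/get_config and the mirrored constructor attributes removed, reconstruction presumably relies on the base SummaryNetwork and the @serializable machinery. A quick round-trip check, sketched under the assumptions that DeepSet is exported from bayesflow.networks and that its remaining constructor arguments have defaults, is one way to confirm the refactor keeps the layer serializable.

# Round-trip sketch under the stated assumptions; not part of this commit.
import numpy as np
from keras.saving import serialize_keras_object, deserialize_keras_object

from bayesflow.networks import DeepSet  # assumed import path

deep_set = DeepSet(summary_dim=8)  # other constructor arguments assumed to default
x = np.random.normal(size=(4, 16, 3)).astype("float32")  # (batch, set size, features)
summaries = deep_set(x)

config = serialize_keras_object(deep_set)
restored = deserialize_keras_object(config)

assert isinstance(restored, DeepSet)
# Compare output shapes to confirm summary_dim survives the round trip.
print(summaries.shape, restored(x).shape)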

bayesflow/networks/deep_set/equivariant_module.py renamed to bayesflow/networks/deep_set/equivariant_layer.py

Lines changed: 7 additions & 9 deletions
@@ -4,17 +4,17 @@
 from keras import ops, layers
 
 from bayesflow.types import Tensor
-from bayesflow.utils import model_kwargs
+from bayesflow.utils import layer_kwargs
 from bayesflow.utils.decorators import sanitize_input_shape
 from bayesflow.utils.serialization import serializable
 
 from ..mlp import MLP
 
-from .invariant_module import InvariantModule
+from .invariant_layer import InvariantLayer
 
 
 @serializable
-class EquivariantModule(keras.Model):
+class EquivariantLayer(keras.Layer):
     """Implements an equivariant module performing an equivariant transform.
 
     For details and justification, see:
@@ -72,32 +72,30 @@ def __init__(
             Whether to apply spectral normalization to stabilize training. Default is False.
         """
 
-        super().__init__(**model_kwargs(kwargs))
+        super().__init__(**layer_kwargs(kwargs))
 
         # Invariant module to increase expressiveness by concatenating outputs to each set member
-        self.invariant_module = InvariantModule(
+        self.invariant_module = InvariantLayer(
             mlp_widths_inner=mlp_widths_invariant_inner,
             mlp_widths_outer=mlp_widths_invariant_outer,
             activation=activation,
             kernel_initializer=kernel_initializer,
             dropout=dropout,
             pooling=pooling,
             spectral_normalization=spectral_normalization,
-            name="invariant_module",
         )
 
         # Fully connected net + residual connection for an equivariant transform applied to each set member
-        self.input_projector = layers.Dense(mlp_widths_equivariant[-1], name="input_projector")
+        self.input_projector = layers.Dense(mlp_widths_equivariant[-1])
         self.equivariant_fc = MLP(
             mlp_widths_equivariant,
             dropout=dropout,
             activation=activation,
             kernel_initializer=kernel_initializer,
             spectral_normalization=spectral_normalization,
-            name="equivariant_fc",
         )
 
-        self.layer_norm = layers.LayerNormalization(name="layer_norm") if layer_norm else None
+        self.layer_norm = layers.LayerNormalization() if layer_norm else None
 
     @sanitize_input_shape
     def build(self, input_shape):
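
The key change in this file is the base class: EquivariantLayer now subclasses keras.Layer instead of keras.Model, the constructor forwards only Layer-compatible kwargs via layer_kwargs, and the explicit name="..." arguments can go because Keras auto-generates unique names for tracked sub-layers. A minimal sketch of that pattern, with TinyEquivariant as a hypothetical stand-in rather than the BayesFlow class:

import keras
from keras import layers


class TinyEquivariant(keras.Layer):
    """Hypothetical stand-in, not the BayesFlow EquivariantLayer."""

    def __init__(self, width: int = 32, **kwargs):
        super().__init__(**kwargs)
        # Sub-layers receive unique auto-generated names, so the explicit
        # name="..." arguments dropped in this commit are not required.
        self.input_projector = layers.Dense(width)
        self.layer_norm = layers.LayerNormalization()

    def call(self, x):
        h = self.input_projector(x)  # applied to each set member independently
        return self.layer_norm(h)


layer = TinyEquivariant(width=8)
print(layer(keras.ops.zeros((2, 5, 3))).shape)  # (2, 5, 8)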

bayesflow/networks/deep_set/invariant_module.py renamed to bayesflow/networks/deep_set/invariant_layer.py

Lines changed: 7 additions & 8 deletions
@@ -3,7 +3,7 @@
 import keras
 
 from bayesflow.types import Tensor
-from bayesflow.utils import model_kwargs
+from bayesflow.utils import layer_kwargs
 from bayesflow.utils import find_pooling
 from bayesflow.utils.decorators import sanitize_input_shape
 from bayesflow.utils.serialization import serializable
@@ -12,7 +12,7 @@
 
 
 @serializable
-class InvariantModule(keras.Model):
+class InvariantLayer(keras.Layer):
     """Implements an invariant module performing a permutation-invariant transform.
 
     For details and rationale, see:
@@ -64,7 +64,7 @@ def __init__(
             Whether to apply spectral normalization to stabilize training. Default is False.
         """
 
-        super().__init__(**model_kwargs(kwargs))
+        super().__init__(**layer_kwargs(kwargs))
 
        # Inner fully connected net for sum decomposition: inner( pooling( inner(set) ) )
        self.inner_fc = MLP(
@@ -88,11 +88,6 @@ def __init__(
             pooling_kwargs = {}
 
         self.pooling_layer = find_pooling(pooling, **pooling_kwargs)
-        self.pooling_layer.name = f"{pooling}_pooling"
-
-    @sanitize_input_shape
-    def build(self, input_shape):
-        self.call(keras.ops.zeros(input_shape))
 
     def call(self, input_set: Tensor, training: bool = False, **kwargs) -> Tensor:
         """Performs the forward pass of a learnable invariant transform.
@@ -114,3 +109,7 @@ def call(self, input_set: Tensor, training: bool = False, **kwargs) -> Tensor:
         set_summary = self.pooling_layer(set_summary, training=training)
         set_summary = self.outer_fc(set_summary, training=training)
         return set_summary
+
+    @sanitize_input_shape
+    def build(self, input_shape):
+        self.call(keras.ops.zeros(input_shape))
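
The relocated build method keeps the same idiom: it runs one forward pass on a zeros tensor so every sub-layer creates its weights eagerly (the @sanitize_input_shape decorator is BayesFlow-internal and presumably only normalizes the incoming shape). A self-contained sketch of the idiom with a hypothetical PooledMLP layer:

import keras
from keras import layers, ops


class PooledMLP(keras.Layer):
    """Hypothetical stand-in illustrating build-by-dummy-forward-pass."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.inner = layers.Dense(16)
        self.outer = layers.Dense(4)

    def call(self, input_set, training=False):
        x = self.inner(input_set)   # (batch, set_size, 16)
        x = ops.mean(x, axis=1)     # permutation-invariant pooling over the set axis
        return self.outer(x)        # (batch, 4)

    def build(self, input_shape):
        # One dummy forward pass creates the weights of every sub-layer at once.
        self.call(ops.zeros(input_shape))


layer = PooledMLP()
layer.build((2, 5, 3))
print([w.shape for w in layer.weights])  # Dense kernels/biases created by the dummy call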

bayesflow/networks/embeddings/fourier_embedding.py

Lines changed: 2 additions & 2 deletions
@@ -2,12 +2,12 @@
 
 import keras
 from keras import ops
-from keras.saving import register_keras_serializable as serializable
 
 from bayesflow.types import Tensor
+from bayesflow.utils.serialization import serializable
 
 
-@serializable(package="bayesflow.networks")
+@serializable
 class FourierEmbedding(keras.Layer):
     """Implements a Fourier projection with normally distributed frequencies."""
 

bayesflow/networks/embeddings/recurrent_embedding.py

Lines changed: 3 additions & 2 deletions
@@ -1,11 +1,12 @@
 import keras
-from keras.saving import register_keras_serializable as serializable
 
 from bayesflow.types import Tensor
 from bayesflow.utils import expand_tile
 
+from bayesflow.utils.serialization import serializable
 
-@serializable(package="bayesflow.networks")
+
+@serializable
 class RecurrentEmbedding(keras.Layer):
     """Implements a recurrent network for flexibly embedding time vectors."""
 

bayesflow/networks/embeddings/time2vec.py

Lines changed: 2 additions & 2 deletions
@@ -1,11 +1,11 @@
 import keras
-from keras.saving import register_keras_serializable as serializable
 
 from bayesflow.types import Tensor
 from bayesflow.utils import expand_tile
+from bayesflow.utils.serialization import serializable
 
 
-@serializable(package="bayesflow.networks")
+@serializable
 class Time2Vec(keras.Layer):
     """
     Implements the Time2Vec learnbale embedding from [1].
