Skip to content

Commit de30009

Browse files
authored
Enable use of summary networks with functional API again (#434)
* summary networks: add tests for using functional API * fix build functions for use with functional API
1 parent 688f22c commit de30009

File tree

7 files changed

+38
-2
lines changed

7 files changed

+38
-2
lines changed

bayesflow/links/ordered.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from keras.saving import register_keras_serializable as serializable
33

44
from bayesflow.utils import layer_kwargs
5+
from bayesflow.utils.decorators import sanitize_input_shape
56

67

78
@serializable(package="links.ordered")
@@ -49,5 +50,6 @@ def call(self, inputs):
4950
x = keras.ops.concatenate([below, anchor_input, above], self.axis)
5051
return x
5152

53+
@sanitize_input_shape
5254
def compute_output_shape(self, input_shape):
5355
return input_shape

bayesflow/networks/summary_network.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ def build(self, input_shape):
2121
if self.base_distribution is not None:
2222
self.base_distribution.build(keras.ops.shape(z))
2323

24+
@sanitize_input_shape
2425
def compute_output_shape(self, input_shape):
2526
return keras.ops.shape(self.call(keras.ops.zeros(input_shape)))
2627

bayesflow/networks/transformers/mab.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from bayesflow.networks import MLP
55
from bayesflow.types import Tensor
66
from bayesflow.utils import layer_kwargs
7+
from bayesflow.utils.decorators import sanitize_input_shape
78
from bayesflow.utils.serialization import serializable
89

910

@@ -122,8 +123,10 @@ def call(self, seq_x: Tensor, seq_y: Tensor, training: bool = False, **kwargs) -
122123
return out
123124

124125
# noinspection PyMethodOverriding
126+
@sanitize_input_shape
125127
def build(self, seq_x_shape, seq_y_shape):
126128
self.call(keras.ops.zeros(seq_x_shape), keras.ops.zeros(seq_y_shape))
127129

130+
@sanitize_input_shape
128131
def compute_output_shape(self, seq_x_shape, seq_y_shape):
129132
return keras.ops.shape(self.call(keras.ops.zeros(seq_x_shape), keras.ops.zeros(seq_y_shape)))

bayesflow/networks/transformers/pma.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from bayesflow.networks import MLP
55
from bayesflow.types import Tensor
66
from bayesflow.utils import layer_kwargs
7+
from bayesflow.utils.decorators import sanitize_input_shape
78
from bayesflow.utils.serialization import serializable
89

910
from .mab import MultiHeadAttentionBlock
@@ -125,5 +126,6 @@ def call(self, input_set: Tensor, training: bool = False, **kwargs) -> Tensor:
125126
summaries = self.mab(seed_tiled, set_x_transformed, training=training, **kwargs)
126127
return ops.reshape(summaries, (ops.shape(summaries)[0], -1))
127128

129+
@sanitize_input_shape
128130
def compute_output_shape(self, input_shape):
129131
return keras.ops.shape(self.call(keras.ops.zeros(input_shape)))

bayesflow/networks/transformers/sab.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import keras
22

33
from bayesflow.types import Tensor
4+
from bayesflow.utils.decorators import sanitize_input_shape
45
from bayesflow.utils.serialization import serializable
56

67
from .mab import MultiHeadAttentionBlock
@@ -16,6 +17,7 @@ class SetAttentionBlock(MultiHeadAttentionBlock):
1617
"""
1718

1819
# noinspection PyMethodOverriding
20+
@sanitize_input_shape
1921
def build(self, input_set_shape):
2022
self.call(keras.ops.zeros(input_set_shape))
2123

@@ -42,5 +44,6 @@ def call(self, input_set: Tensor, training: bool = False, **kwargs) -> Tensor:
4244
return super().call(input_set, input_set, training=training, **kwargs)
4345

4446
# noinspection PyMethodOverriding
47+
@sanitize_input_shape
4548
def compute_output_shape(self, input_set_shape):
4649
return keras.ops.shape(self.call(keras.ops.zeros(input_set_shape)))

bayesflow/utils/decorators.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def callback(x):
114114

115115

116116
def sanitize_input_shape(fn: Callable):
117-
"""Decorator to replace the first dimension in input_shape with a dummy batch size if it is None"""
117+
"""Decorator to replace the first dimension in ..._shape arguments with a dummy batch size if it is None"""
118118

119119
# The Keras functional API passes input_shape = (None, second_dim, third_dim, ...), which
120120
# causes problems when constructions like self.call(keras.ops.zeros(input_shape)) are used
@@ -126,5 +126,8 @@ def callback(input_shape: Shape) -> Shape:
126126
return tuple(input_shape)
127127
return input_shape
128128

129-
fn = argument_callback("input_shape", callback)(fn)
129+
args = inspect.getfullargspec(fn).args
130+
for arg in args:
131+
if arg.endswith("_shape"):
132+
fn = argument_callback(arg, callback)(fn)
130133
return fn

tests/test_networks/test_summary_networks.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,28 @@ def test_build(automatic, summary_network, random_set):
2525
assert summary_network.variables, "Model has no variables."
2626

2727

28+
@pytest.mark.parametrize("automatic", [True, False])
29+
def test_build_functional_api(automatic, summary_network, random_set):
30+
if summary_network is None:
31+
pytest.skip(reason="Nothing to do, because there is no summary network.")
32+
33+
assert summary_network.built is False
34+
35+
inputs = keras.layers.Input(shape=keras.ops.shape(random_set)[1:])
36+
outputs = summary_network(inputs)
37+
model = keras.Model(inputs=inputs, outputs=outputs)
38+
39+
if automatic:
40+
model(random_set)
41+
else:
42+
model.build(keras.ops.shape(random_set))
43+
44+
assert model.built is True
45+
46+
# check the model has variables
47+
assert summary_network.variables, "Model has no variables."
48+
49+
2850
def test_variable_batch_size(summary_network, random_set):
2951
if summary_network is None:
3052
pytest.skip(reason="Nothing to do, because there is no summary network.")

0 commit comments

Comments
 (0)