diff --git a/bayesflow/links/ordered.py b/bayesflow/links/ordered.py
index 47be02317..77545b6f8 100644
--- a/bayesflow/links/ordered.py
+++ b/bayesflow/links/ordered.py
@@ -2,6 +2,7 @@
 from keras.saving import register_keras_serializable as serializable
 
 from bayesflow.utils import layer_kwargs
+from bayesflow.utils.decorators import sanitize_input_shape
 
 
 @serializable(package="links.ordered")
@@ -49,5 +50,6 @@ def call(self, inputs):
         x = keras.ops.concatenate([below, anchor_input, above], self.axis)
         return x
 
+    @sanitize_input_shape
     def compute_output_shape(self, input_shape):
         return input_shape
diff --git a/bayesflow/networks/summary_network.py b/bayesflow/networks/summary_network.py
index 316df39e6..6e97c618f 100644
--- a/bayesflow/networks/summary_network.py
+++ b/bayesflow/networks/summary_network.py
@@ -21,6 +21,7 @@ def build(self, input_shape):
         if self.base_distribution is not None:
             self.base_distribution.build(keras.ops.shape(z))
 
+    @sanitize_input_shape
     def compute_output_shape(self, input_shape):
         return keras.ops.shape(self.call(keras.ops.zeros(input_shape)))
 
diff --git a/bayesflow/networks/transformers/mab.py b/bayesflow/networks/transformers/mab.py
index a2e22da16..8f0e3f881 100644
--- a/bayesflow/networks/transformers/mab.py
+++ b/bayesflow/networks/transformers/mab.py
@@ -4,6 +4,7 @@
 from bayesflow.networks import MLP
 from bayesflow.types import Tensor
 from bayesflow.utils import layer_kwargs
+from bayesflow.utils.decorators import sanitize_input_shape
 from bayesflow.utils.serialization import serializable
 
 
@@ -122,8 +123,10 @@ def call(self, seq_x: Tensor, seq_y: Tensor, training: bool = False, **kwargs) -> Tensor:
         return out
 
     # noinspection PyMethodOverriding
+    @sanitize_input_shape
     def build(self, seq_x_shape, seq_y_shape):
         self.call(keras.ops.zeros(seq_x_shape), keras.ops.zeros(seq_y_shape))
 
+    @sanitize_input_shape
     def compute_output_shape(self, seq_x_shape, seq_y_shape):
         return keras.ops.shape(self.call(keras.ops.zeros(seq_x_shape), keras.ops.zeros(seq_y_shape)))
diff --git a/bayesflow/networks/transformers/pma.py b/bayesflow/networks/transformers/pma.py
index 5eb6a269d..956c85b48 100644
--- a/bayesflow/networks/transformers/pma.py
+++ b/bayesflow/networks/transformers/pma.py
@@ -4,6 +4,7 @@
 from bayesflow.networks import MLP
 from bayesflow.types import Tensor
 from bayesflow.utils import layer_kwargs
+from bayesflow.utils.decorators import sanitize_input_shape
 from bayesflow.utils.serialization import serializable
 
 from .mab import MultiHeadAttentionBlock
@@ -125,5 +126,6 @@ def call(self, input_set: Tensor, training: bool = False, **kwargs) -> Tensor:
         summaries = self.mab(seed_tiled, set_x_transformed, training=training, **kwargs)
         return ops.reshape(summaries, (ops.shape(summaries)[0], -1))
 
+    @sanitize_input_shape
     def compute_output_shape(self, input_shape):
         return keras.ops.shape(self.call(keras.ops.zeros(input_shape)))
diff --git a/bayesflow/networks/transformers/sab.py b/bayesflow/networks/transformers/sab.py
index a69dc5fa4..a447d92a2 100644
--- a/bayesflow/networks/transformers/sab.py
+++ b/bayesflow/networks/transformers/sab.py
@@ -1,6 +1,7 @@
 import keras
 
 from bayesflow.types import Tensor
+from bayesflow.utils.decorators import sanitize_input_shape
 from bayesflow.utils.serialization import serializable
 
 from .mab import MultiHeadAttentionBlock
@@ -16,6 +17,7 @@ class SetAttentionBlock(MultiHeadAttentionBlock):
     """
 
     # noinspection PyMethodOverriding
+    @sanitize_input_shape
     def build(self, input_set_shape):
         self.call(keras.ops.zeros(input_set_shape))
 
@@ -42,5 +44,6 @@ def call(self, input_set: Tensor, training: bool = False, **kwargs) -> Tensor:
         return super().call(input_set, input_set, training=training, **kwargs)
 
     # noinspection PyMethodOverriding
+    @sanitize_input_shape
     def compute_output_shape(self, input_set_shape):
         return keras.ops.shape(self.call(keras.ops.zeros(input_set_shape)))
diff --git a/bayesflow/utils/decorators.py b/bayesflow/utils/decorators.py
index 91afc9fb7..7fd32edc9 100644
--- a/bayesflow/utils/decorators.py
+++ b/bayesflow/utils/decorators.py
@@ -114,7 +114,7 @@ def callback(x):
 
 
 def sanitize_input_shape(fn: Callable):
-    """Decorator to replace the first dimension in input_shape with a dummy batch size if it is None"""
+    """Decorator to replace the first dimension in ..._shape arguments with a dummy batch size if it is None"""
 
     # The Keras functional API passes input_shape = (None, second_dim, third_dim, ...), which
     # causes problems when constructions like self.call(keras.ops.zeros(input_shape)) are used
@@ -126,5 +126,8 @@ def callback(input_shape: Shape) -> Shape:
             return tuple(input_shape)
         return input_shape
 
-    fn = argument_callback("input_shape", callback)(fn)
+    args = inspect.getfullargspec(fn).args
+    for arg in args:
+        if arg.endswith("_shape"):
+            fn = argument_callback(arg, callback)(fn)
     return fn
diff --git a/tests/test_networks/test_summary_networks.py b/tests/test_networks/test_summary_networks.py
index 082ce4d25..50e1726c1 100644
--- a/tests/test_networks/test_summary_networks.py
+++ b/tests/test_networks/test_summary_networks.py
@@ -25,6 +25,28 @@ def test_build(automatic, summary_network, random_set):
     assert summary_network.variables, "Model has no variables."
 
 
+@pytest.mark.parametrize("automatic", [True, False])
+def test_build_functional_api(automatic, summary_network, random_set):
+    if summary_network is None:
+        pytest.skip(reason="Nothing to do, because there is no summary network.")
+
+    assert summary_network.built is False
+
+    inputs = keras.layers.Input(shape=keras.ops.shape(random_set)[1:])
+    outputs = summary_network(inputs)
+    model = keras.Model(inputs=inputs, outputs=outputs)
+
+    if automatic:
+        model(random_set)
+    else:
+        model.build(keras.ops.shape(random_set))
+
+    assert model.built is True
+
+    # check the model has variables
+    assert summary_network.variables, "Model has no variables."
+
+
 def test_variable_batch_size(summary_network, random_set):
     if summary_network is None:
         pytest.skip(reason="Nothing to do, because there is no summary network.")