Skip to content

Commit 07ec251

Browse files
authored
feat: allow use of functional API by sanitizing build input shapes (#332)
1 parent ec2a0e2 commit 07ec251

File tree

7 files changed

+30
-0
lines changed

7 files changed

+30
-0
lines changed

bayesflow/networks/deep_set/deep_set.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from bayesflow.types import Tensor
77
from bayesflow.utils import filter_kwargs
8+
from bayesflow.utils.decorators import sanitize_input_shape
89

910
from .equivariant_module import EquivariantModule
1011
from .invariant_module import InvariantModule
@@ -78,6 +79,7 @@ def __init__(
7879
self.output_projector = keras.layers.Dense(summary_dim, activation="linear")
7980
self.summary_dim = summary_dim
8081

82+
@sanitize_input_shape
8183
def build(self, input_shape):
8284
super().build(input_shape)
8385
self.call(keras.ops.zeros(input_shape))

bayesflow/networks/deep_set/equivariant_module.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from keras.saving import register_keras_serializable as serializable
66

77
from bayesflow.types import Tensor
8+
from bayesflow.utils.decorators import sanitize_input_shape
89
from .invariant_module import InvariantModule
910

1011

@@ -66,6 +67,7 @@ def __init__(
6667

6768
self.layer_norm = layers.LayerNormalization() if layer_norm else None
6869

70+
@sanitize_input_shape
6971
def build(self, input_shape):
7072
self.call(keras.ops.zeros(input_shape))
7173

bayesflow/networks/deep_set/invariant_module.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from bayesflow.types import Tensor
88
from bayesflow.utils import find_pooling
9+
from bayesflow.utils.decorators import sanitize_input_shape
910

1011

1112
@serializable(package="bayesflow.networks")
@@ -76,6 +77,7 @@ def __init__(
7677

7778
self.pooling_layer = find_pooling(pooling, **pooling_kwargs)
7879

80+
@sanitize_input_shape
7981
def build(self, input_shape):
8082
self.call(keras.ops.zeros(input_shape))
8183

bayesflow/networks/lstnet/lstnet.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from keras.saving import register_keras_serializable as serializable
33

44
from bayesflow.types import Tensor
5+
from bayesflow.utils.decorators import sanitize_input_shape
56
from .skip_recurrent import SkipRecurrentNet
67
from ..summary_network import SummaryNetwork
78

@@ -78,6 +79,7 @@ def call(self, x: Tensor, training: bool = False, **kwargs) -> Tensor:
7879
x = self.output_projector(x)
7980
return x
8081

82+
@sanitize_input_shape
8183
def build(self, input_shape):
8284
super().build(input_shape)
8385
self.call(keras.ops.zeros(input_shape))

bayesflow/networks/lstnet/skip_recurrent.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from bayesflow.types import Tensor
55
from bayesflow.utils import keras_kwargs, find_recurrent_net
6+
from bayesflow.utils.decorators import sanitize_input_shape
67

78

89
@serializable(package="bayesflow.networks")
@@ -58,5 +59,6 @@ def call(self, time_series: Tensor, training: bool = False, **kwargs) -> Tensor:
5859
skip_summary = self.skip_recurrent(self.skip_conv(time_series), training=training)
5960
return keras.ops.concatenate((direct_summary, skip_summary), axis=-1)
6061

62+
@sanitize_input_shape
6163
def build(self, input_shape):
6264
self.call(keras.ops.zeros(input_shape))

bayesflow/networks/summary_network.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,15 @@
33
from bayesflow.metrics.functional import maximum_mean_discrepancy
44
from bayesflow.types import Tensor
55
from bayesflow.utils import find_distribution, keras_kwargs
6+
from bayesflow.utils.decorators import sanitize_input_shape
67

78

89
class SummaryNetwork(keras.Layer):
910
def __init__(self, base_distribution: str = None, **kwargs):
1011
super().__init__(**keras_kwargs(kwargs))
1112
self.base_distribution = find_distribution(base_distribution)
1213

14+
@sanitize_input_shape
1315
def build(self, input_shape):
1416
if self.base_distribution is not None:
1517
output_shape = keras.ops.shape(self.call(keras.ops.zeros(input_shape)))

bayesflow/utils/decorators.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from functools import wraps
33
import inspect
44
from typing import overload, TypeVar
5+
from bayesflow.types import Shape
56

67
Fn = TypeVar("Fn", bound=Callable[..., any])
78

@@ -110,3 +111,20 @@ def callback(x):
110111
fn = alias("batch_shape", "batch_size")(fn)
111112

112113
return fn
114+
115+
116+
def sanitize_input_shape(fn: Callable):
    """Decorator that swaps a leading ``None`` in ``input_shape`` for a dummy batch size."""

    # The Keras functional API hands build() an input_shape of the form
    # (None, second_dim, third_dim, ...). Constructions such as
    # self.call(keras.ops.zeros(input_shape)) inside build cannot cope with the
    # None entry, so we substitute an arbitrary concrete batch size instead.
    def replace_none_batch(input_shape: Shape) -> Shape:
        # Leave fully-specified shapes untouched.
        if input_shape[0] is not None:
            return input_shape
        sanitized = list(input_shape)
        sanitized[0] = 32
        return tuple(sanitized)

    return argument_callback("input_shape", replace_none_batch)(fn)

0 commit comments

Comments
 (0)