diff --git a/.gitignore b/.gitignore index 1ca9eaef6..e5e86a509 100644 --- a/.gitignore +++ b/.gitignore @@ -13,6 +13,8 @@ docsrc/source/contributing.md examples/checkpoints/ build docs/ +.venv/ +.env # mypy diff --git a/bayesflow/networks/deep_set/equivariant_layer.py b/bayesflow/networks/deep_set/equivariant_layer.py index 7e35ad9bb..95b940789 100644 --- a/bayesflow/networks/deep_set/equivariant_layer.py +++ b/bayesflow/networks/deep_set/equivariant_layer.py @@ -99,8 +99,8 @@ def __init__( self.layer_norm = layers.LayerNormalization() if layer_norm else None @sanitize_input_shape - def build(self, input_shape): - self.call(keras.ops.zeros(input_shape)) + def build(self, input_set_shape): + self.call(keras.ops.zeros(input_set_shape)) def call(self, input_set: Tensor, training: bool = False, **kwargs) -> Tensor: """Performs the forward pass of a learnable equivariant transform. diff --git a/bayesflow/networks/deep_set/invariant_layer.py b/bayesflow/networks/deep_set/invariant_layer.py index 2f29c6b8d..8a7a3a479 100644 --- a/bayesflow/networks/deep_set/invariant_layer.py +++ b/bayesflow/networks/deep_set/invariant_layer.py @@ -115,5 +115,5 @@ def call(self, input_set: Tensor, training: bool = False, **kwargs) -> Tensor: return set_summary @sanitize_input_shape - def build(self, input_shape): - self.call(keras.ops.zeros(input_shape)) + def build(self, input_set_shape): + self.call(keras.ops.zeros(input_set_shape))