Skip to content

Commit 628227a

Browse files
committed
Fix Keras build/call parameter naming in DeepSet layers
1 parent 6c97e76 commit 628227a

File tree

3 files changed

+6
-4
lines changed

3 files changed

+6
-4
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ docsrc/source/contributing.md
1313
examples/checkpoints/
1414
build
1515
docs/
16+
.venv/
17+
.env
1618

1719

1820
# mypy

bayesflow/networks/deep_set/equivariant_layer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,8 @@ def __init__(
9999
self.layer_norm = layers.LayerNormalization() if layer_norm else None
100100

101101
@sanitize_input_shape
102-
def build(self, input_shape):
103-
self.call(keras.ops.zeros(input_shape))
102+
def build(self, input_set_shape):
103+
self.call(keras.ops.zeros(input_set_shape))
104104

105105
def call(self, input_set: Tensor, training: bool = False, **kwargs) -> Tensor:
106106
"""Performs the forward pass of a learnable equivariant transform.

bayesflow/networks/deep_set/invariant_layer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,5 +115,5 @@ def call(self, input_set: Tensor, training: bool = False, **kwargs) -> Tensor:
115115
return set_summary
116116

117117
@sanitize_input_shape
118-
def build(self, input_shape):
119-
self.call(keras.ops.zeros(input_shape))
118+
def build(self, input_set_shape):
119+
self.call(keras.ops.zeros(input_set_shape))

0 commit comments

Comments
 (0)