Skip to content

Commit 2c161c6

Browse files
committed
fix kwargs in set transformer
1 parent 09df093 commit 2c161c6

File tree

3 files changed

+6
-9
lines changed

3 files changed

+6
-9
lines changed

bayesflow/networks/transformers/isab.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,5 +107,6 @@ def call(self, input_set: Tensor, training: bool = False, **kwargs) -> Tensor:
107107
batch_size = keras.ops.shape(input_set)[0]
108108
inducing_points_expanded = keras.ops.expand_dims(self.inducing_points, axis=0)
109109
inducing_points_tiled = keras.ops.tile(inducing_points_expanded, [batch_size, 1, 1])
110+
print(kwargs)
110111
h = self.mab0(inducing_points_tiled, input_set, training=training, **kwargs)
111112
return self.mab1(input_set, h, training=training, **kwargs)

bayesflow/networks/transformers/mab.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from bayesflow.networks import MLP
55
from bayesflow.types import Tensor
6-
from bayesflow.utils import layer_kwargs, filter_kwargs
6+
from bayesflow.utils import layer_kwargs
77
from bayesflow.utils.decorators import sanitize_input_shape
88
from bayesflow.utils.serialization import serializable
99

@@ -111,7 +111,7 @@ def call(self, seq_x: Tensor, seq_y: Tensor, training: bool = False, **kwargs) -
111111
"""
112112

113113
h = self.input_projector(seq_x) + self.attention(
114-
query=seq_x, key=seq_y, value=seq_y, training=training, **filter_kwargs(kwargs, self.attention.call)
114+
query=seq_x, key=seq_y, value=seq_y, training=training, **kwargs
115115
)
116116
if self.ln_pre is not None:
117117
h = self.ln_pre(h, training=training)

bayesflow/networks/transformers/set_transformer.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import keras
22

33
from bayesflow.types import Tensor
4-
from bayesflow.utils import check_lengths_same, filter_kwargs
4+
from bayesflow.utils import check_lengths_same
55
from bayesflow.utils.serialization import serializable
66

77
from ..summary_network import SummaryNetwork
@@ -147,11 +147,7 @@ def call(self, input_set: Tensor, training: bool = False, **kwargs) -> Tensor:
147147
out : Tensor
148148
Output of shape (batch_size, set_size, output_dim)
149149
"""
150-
summary = self.attention_blocks(
151-
input_set, training=training, **filter_kwargs(kwargs, self.attention_blocks.call)
152-
)
153-
summary = self.pooling_by_attention(
154-
summary, training=training, **filter_kwargs(kwargs, self.pooling_by_attention.call)
155-
)
150+
summary = self.attention_blocks(input_set, training=training)
151+
summary = self.pooling_by_attention(summary, training=training)
156152
summary = self.output_projector(summary)
157153
return summary

0 commit comments

Comments (0)