
Commit 9d4c1a1

fix kwargs in set transformer
1 parent 2c161c6

File tree: 2 files changed (+3, -3 lines)


bayesflow/networks/transformers/mab.py

Lines changed: 2 additions & 2 deletions

@@ -3,7 +3,7 @@

 from bayesflow.networks import MLP
 from bayesflow.types import Tensor
-from bayesflow.utils import layer_kwargs
+from bayesflow.utils import layer_kwargs, filter_kwargs
 from bayesflow.utils.decorators import sanitize_input_shape
 from bayesflow.utils.serialization import serializable

@@ -111,7 +111,7 @@ def call(self, seq_x: Tensor, seq_y: Tensor, training: bool = False, **kwargs) -
         """

         h = self.input_projector(seq_x) + self.attention(
-            query=seq_x, key=seq_y, value=seq_y, training=training, **kwargs
+            query=seq_x, key=seq_y, value=seq_y, training=training, **filter_kwargs(kwargs, self.attention.call)
         )
         if self.ln_pre is not None:
             h = self.ln_pre(h, training=training)
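The mab.py change stops forwarding arbitrary **kwargs straight into self.attention: an attention layer's call typically accepts only a fixed set of keywords, so any extra key would raise a TypeError. filter_kwargs keeps only the entries the callee can accept. A minimal sketch of what such a helper can look like, assuming signature-based filtering (the actual bayesflow.utils.filter_kwargs may differ in details):

    # Illustrative sketch only; not bayesflow's actual implementation.
    import inspect

    def filter_kwargs(kwargs: dict, f) -> dict:
        """Keep only the entries of `kwargs` that callable `f` can accept."""
        params = inspect.signature(f).parameters
        # If `f` itself declares **kwargs, everything passes through unchanged.
        if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
            return dict(kwargs)
        return {k: v for k, v in kwargs.items() if k in params}

Under this assumption, a keyword the attention layer understands (e.g. attention_mask) is passed through, while unrelated keywords carried along in **kwargs are silently dropped instead of crashing the call.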

bayesflow/networks/transformers/set_transformer.py

Lines changed: 1 addition & 1 deletion

@@ -148,6 +148,6 @@ def call(self, input_set: Tensor, training: bool = False, **kwargs) -> Tensor:
             Output of shape (batch_size, set_size, output_dim)
         """
         summary = self.attention_blocks(input_set, training=training)
-        summary = self.pooling_by_attention(summary, training=training)
+        summary = self.pooling_by_attention(summary, training=training, **kwargs)
         summary = self.output_projector(summary)
         return summary
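The set_transformer.py change forwards the caller's **kwargs to the pooling step, so keyword arguments passed into the set transformer's call can now reach the attention-based pooling layer. A toy, self-contained illustration of this forwarding pattern (the class names below are hypothetical, not bayesflow's):

    # Before the fix, a keyword like `mask` given to Outer.call
    # never reached the wrapped inner layer.
    class Inner:
        def call(self, x, training=False, mask=None):
            return [v for i, v in enumerate(x) if mask is None or mask[i]]

    class Outer:
        def __init__(self):
            self.inner = Inner()

        def call(self, x, training=False, **kwargs):
            # Forwarding **kwargs lets caller-supplied keywords (here: mask)
            # reach the wrapped layer, mirroring the set_transformer.py fix.
            return self.inner.call(x, training=training, **kwargs)

    print(Outer().call([1, 2, 3], mask=[True, False, True]))  # [1, 3]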

0 commit comments
