
Commit 967c4e5

Bring back MultiHeadAttention after keras bugfix
1 parent: 0c1c9bc

2 files changed: +2 additions, -156 deletions

bayesflow/networks/transformers/mab.py

Lines changed: 2 additions & 3 deletions
@@ -4,7 +4,6 @@
 
 from bayesflow.types import Tensor
 from bayesflow.networks import MLP
-from .mha import MultiHeadAttention
 
 
 @serializable(package="bayesflow.networks")
@@ -40,8 +39,8 @@ def __init__(
         super().__init__(**kwargs)
 
         self.input_projector = layers.Dense(embed_dim)
-        self.attention = MultiHeadAttention(
-            embed_dim=embed_dim,
+        self.attention = layers.MultiHeadAttention(
+            key_dim=embed_dim,
             num_heads=num_heads,
             dropout=dropout,
             use_bias=use_bias,

bayesflow/networks/transformers/mha.py

Lines changed: 0 additions & 153 deletions
This file was deleted.
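
For context, here is a minimal sketch of how the built-in keras.layers.MultiHeadAttention that this commit switches back to is typically used. The shapes and hyperparameter values below are illustrative assumptions, not taken from the repository. Note that Keras' key_dim argument is the size of each attention head, whereas the removed custom wrapper took a single embed_dim.

import keras
from keras import layers

# Illustrative values; BayesFlow's actual defaults may differ.
embed_dim, num_heads = 64, 4

attention = layers.MultiHeadAttention(
    key_dim=embed_dim,    # size of each attention head for query and key
    num_heads=num_heads,
    dropout=0.05,
    use_bias=True,
)

# Self-attention over a batch of 8 sequences of length 16.
x = keras.random.normal(shape=(8, 16, embed_dim))
out = attention(query=x, value=x, key=x)  # output shape: (8, 16, embed_dim)

By default the Keras layer projects its output back to the last dimension of the query, so the call above returns tensors with the same trailing dimension that went in, matching what the deleted wrapper provided.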
