
Commit 81fe388: Add doc for PMA
Parent: bef91de

File tree: 1 file changed, +27 -5 lines
  • bayesflow/networks/transformers
bayesflow/networks/transformers/pma.py

Lines changed: 27 additions & 5 deletions
@@ -35,14 +35,36 @@ def __init__(
         layer_norm: bool = True,
         **kwargs,
     ):
-        """Creates a multi-head attention block (MAB) which will perform cross-attention between an input sequence
-        and a set of seed vectors (typically one for a single summary) with summary_dim output dimensions.
-
-        Could also be used as part of a ``DeepSet`` for representing learnable instead of fixed pooling.
+        """
+        Creates a PoolingByMultiHeadAttention (PMA) block for permutation-invariant set encoding using
+        multi-head attention pooling. Can also be used as a building block for `DeepSet` architectures.
 
         Parameters
         ----------
-        ##TODO
+        num_seeds : int, optional (default=1)
+            Number of seed vectors used for pooling. Acts as the number of summary outputs.
+        embed_dim : int, optional (default=64)
+            Dimensionality of the embedding space used in the attention mechanism.
+        num_heads : int, optional (default=4)
+            Number of attention heads in the multi-head attention block.
+        seed_dim : int or None, optional (default=None)
+            Dimensionality of each seed vector. If None, defaults to `embed_dim`.
+        dropout : float, optional (default=0.05)
+            Dropout rate applied to attention and MLP layers.
+        mlp_depth : int, optional (default=2)
+            Number of layers in the feedforward MLP applied before attention.
+        mlp_width : int, optional (default=128)
+            Number of units in each hidden layer of the MLP.
+        mlp_activation : str, optional (default="gelu")
+            Activation function used in the MLP.
+        kernel_initializer : str, optional (default="he_normal")
+            Initializer for kernel weights in dense layers.
+        use_bias : bool, optional (default=True)
+            Whether to include bias terms in dense layers.
+        layer_norm : bool, optional (default=True)
+            Whether to apply layer normalization before and after attention.
+        **kwargs : dict
+            Additional keyword arguments passed to the Keras Layer base class.
         """
 
         super().__init__(**kwargs)
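
The new docstring documents a standard attention-pooling API, so a short usage sketch may help readers of this commit. The snippet below is a minimal sketch only: the import path bayesflow.networks.transformers.pma, the class name PoolingByMultiHeadAttention, and the Keras-style call on a (batch, set_size, input_dim) tensor are assumptions inferred from the file location and the documented parameters, not confirmed by the diff.

# Minimal usage sketch. Assumptions (not confirmed by this commit): the class is
# exposed as PoolingByMultiHeadAttention in bayesflow/networks/transformers/pma.py
# and follows the usual Keras Layer call convention on set-shaped inputs.
import numpy as np

from bayesflow.networks.transformers.pma import PoolingByMultiHeadAttention

# Configure the block, spelling out the documented defaults.
pma = PoolingByMultiHeadAttention(
    num_seeds=1,                # one seed vector -> one pooled summary per set
    embed_dim=64,
    num_heads=4,
    seed_dim=None,              # falls back to embed_dim per the docstring
    dropout=0.05,
    mlp_depth=2,
    mlp_width=128,
    mlp_activation="gelu",
    kernel_initializer="he_normal",
    use_bias=True,
    layer_norm=True,
)

# A batch of 8 sets, each containing 50 exchangeable elements of dimension 3.
x = np.random.normal(size=(8, 50, 3)).astype("float32")

# Cross-attention between the learnable seed vector(s) and the set elements pools
# each set into num_seeds summary vectors; the result is invariant to permutations
# of the 50 elements. The exact output shape depends on the implementation.
summary = pma(x)
print(summary.shape)

Because pooling is driven by learnable seed vectors rather than a fixed mean or max, the same block can replace the fixed pooling stage of a DeepSet, which is the use case the docstring mentions.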
