@@ -35,14 +35,36 @@ def __init__(
         layer_norm: bool = True,
         **kwargs,
     ):
38- """Creates a multi-head attention block (MAB) which will perform cross-attention between an input sequence
39- and a set of seed vectors (typically one for a single summary) with summary_dim output dimensions.
40-
41- Could also be used as part of a ``DeepSet`` for representing learnable instead of fixed pooling.
38+ """
39+ Creates a PoolingByMultiHeadAttention (PMA) block for permutation-invariant set encoding using
40+ multi-head attention pooling. Can also be used us a building block for `DeepSet` architectures.
 
         Parameters
         ----------
-        ##TODO
+        num_seeds : int, optional (default=1)
+            Number of seed vectors used for pooling. Acts as the number of summary outputs.
+        embed_dim : int, optional (default=64)
+            Dimensionality of the embedding space used in the attention mechanism.
+        num_heads : int, optional (default=4)
+            Number of attention heads in the multi-head attention block.
+        seed_dim : int or None, optional (default=None)
+            Dimensionality of each seed vector. If None, defaults to ``embed_dim``.
+        dropout : float, optional (default=0.05)
+            Dropout rate applied to attention and MLP layers.
+        mlp_depth : int, optional (default=2)
+            Number of layers in the feedforward MLP applied before attention.
+        mlp_width : int, optional (default=128)
+            Number of units in each hidden layer of the MLP.
+        mlp_activation : str, optional (default="gelu")
+            Activation function used in the MLP.
+        kernel_initializer : str, optional (default="he_normal")
+            Initializer for kernel weights in dense layers.
+        use_bias : bool, optional (default=True)
+            Whether to include bias terms in dense layers.
+        layer_norm : bool, optional (default=True)
+            Whether to apply layer normalization before and after attention.
+        **kwargs : dict
+            Additional keyword arguments passed to the Keras Layer base class.
4668 """
4769
4870 super ().__init__ (** kwargs )
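
For context, here is a minimal usage sketch of the layer documented above. Only the constructor arguments come from the docstring; the import path, call signature, and output shape are illustrative assumptions, not confirmed by this diff.

```python
# Minimal usage sketch; only the constructor arguments are confirmed by the
# docstring above. The import path is hypothetical.
import numpy as np

from my_package.pooling import PoolingByMultiHeadAttention  # hypothetical path

# One seed vector -> one permutation-invariant summary per input set.
pma = PoolingByMultiHeadAttention(
    num_seeds=1,
    embed_dim=64,
    num_heads=4,
    dropout=0.05,
)

# A batch of 8 sets, each containing 100 elements with 3 features.
x = np.random.normal(size=(8, 100, 3)).astype("float32")

# The learnable seed vectors cross-attend over the set elements, so the
# result is invariant to the order of the 100 elements within each set.
summary = pma(x)  # assumed output shape: (8, num_seeds, embed_dim) = (8, 1, 64)
```

Because pooling happens through attention rather than a fixed mean or max, the summary weights set elements adaptively, which is what makes PMA a learnable drop-in for the fixed pooling stage of a ``DeepSet``.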