
Commit 1474bb6

Add docs for SAB / ISAB [skip ci]

1 parent 69e2387

2 files changed: 42 additions, 2 deletions

bayesflow/networks/transformers/isab.py

Lines changed: 22 additions & 1 deletion
@@ -34,7 +34,28 @@ def __init__(
 
         Parameters
         ----------
-        #TODO
+        num_inducing_points : int, optional
+            The number of inducing points for set-based dimensionality reduction.
+        embed_dim : int, optional
+            Dimensionality of the embedding space, by default 64.
+        num_heads : int, optional
+            Number of attention heads, by default 4.
+        dropout : float, optional
+            Dropout rate applied to attention and MLP layers, by default 0.05.
+        mlp_depth : int, optional
+            Number of layers in the feedforward MLP block, by default 2.
+        mlp_width : int, optional
+            Width of each hidden layer in the MLP block, by default 128.
+        mlp_activation : str, optional
+            Activation function used in the MLP block, by default "gelu".
+        kernel_initializer : str, optional
+            Initializer for kernel weights, by default "he_normal".
+        use_bias : bool, optional
+            Whether to include bias terms in dense layers, by default True.
+        layer_norm : bool, optional
+            Whether to apply layer normalization before and after attention, by default True.
+        **kwargs : dict
+            Additional keyword arguments passed to the Keras Layer base class.
         """
 
         super().__init__(**kwargs)
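For orientation, a minimal usage sketch of the constructor whose parameters this commit documents. The class name InducedSetAttentionBlock and the import path are assumptions inferred from the file location isab.py; only the keyword arguments and their defaults are taken from the diff above.

# Hypothetical sketch -- class name and import path are assumed from the file
# bayesflow/networks/transformers/isab.py; the keyword arguments and their
# defaults come from the docstring added in this commit.
from bayesflow.networks.transformers.isab import InducedSetAttentionBlock

isab = InducedSetAttentionBlock(
    num_inducing_points=16,          # inducing points for set-size reduction
    embed_dim=64,                    # embedding dimensionality (default 64)
    num_heads=4,                     # attention heads (default 4)
    dropout=0.05,                    # dropout on attention and MLP layers
    mlp_depth=2,                     # layers in the feedforward MLP block
    mlp_width=128,                   # hidden width of each MLP layer
    mlp_activation="gelu",           # MLP activation (default "gelu")
    kernel_initializer="he_normal",  # kernel weight initializer
    use_bias=True,                   # bias terms in dense layers
    layer_norm=True,                 # layer norm before and after attention
)

The block would then be applied to a set-valued tensor of shape (batch, set_size, input_dim); the exact call signature is not part of this commit.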

bayesflow/networks/transformers/mab.py

Lines changed: 20 additions & 1 deletion
@@ -33,7 +33,26 @@ def __init__(
 
         Parameters
         ----------
-        ##TODO
+        embed_dim : int, optional
+            Dimensionality of the embedding space, by default 64.
+        num_heads : int, optional
+            Number of attention heads, by default 4.
+        dropout : float, optional
+            Dropout rate applied to attention and MLP layers, by default 0.05.
+        mlp_depth : int, optional
+            Number of layers in the feedforward MLP block, by default 2.
+        mlp_width : int, optional
+            Width of each hidden layer in the MLP block, by default 128.
+        mlp_activation : str, optional
+            Activation function used in the MLP block, by default "gelu".
+        kernel_initializer : str, optional
+            Initializer for kernel weights, by default "he_normal".
+        use_bias : bool, optional
+            Whether to include bias terms in dense layers, by default True.
+        layer_norm : bool, optional
+            Whether to apply layer normalization before and after attention, by default True.
+        **kwargs : dict
+            Additional keyword arguments passed to the Keras Layer base class.
         """
 
         super().__init__(**kwargs)
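A corresponding sketch for the attention block in mab.py. The class name MultiHeadAttentionBlock and the two-argument call convention (one set attending to another, as in the Set Transformer's MAB, Lee et al. 2019) are assumptions not confirmed by this commit; the constructor arguments mirror the parameters documented above.

# Hypothetical sketch -- class name and call convention are assumed from the
# Set Transformer design (MAB(x, y): x attends to y), not from this commit.
import keras
from bayesflow.networks.transformers.mab import MultiHeadAttentionBlock

mab = MultiHeadAttentionBlock(embed_dim=64, num_heads=4, dropout=0.05)

x = keras.random.normal((8, 10, 64))  # query set: (batch, set_size, dim)
y = keras.random.normal((8, 32, 64))  # key/value set
out = mab(x, y)                       # cross-attention output, shape (8, 10, 64)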
