Skip to content

Commit e0f8b4e

Browse files
authored
Merge pull request #392 from bayesflow-org/remove-width-and-depth-args-in-mlp
remove width and depth argument in MLP
2 parents 512b323 + fbd9a8b commit e0f8b4e

File tree

3 files changed

+5
-25
lines changed

3 files changed

+5
-25
lines changed

bayesflow/experimental/cif/conditional_gaussian.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ def __init__(self, depth: int = 4, width: int = 128, activation: str = "swish",
3333
"""
3434

3535
super().__init__(**keras_kwargs(kwargs))
36-
self.means = MLP(depth=depth, width=width, activation=activation)
37-
self.stds = MLP(depth=depth, width=width, activation=activation)
36+
self.means = MLP([width] * depth, activation=activation)
37+
self.stds = MLP([width] * depth, activation=activation)
3838
self.output_projector = keras.layers.Dense(None)
3939

4040
def build(self, input_shape: Shape) -> None:

bayesflow/networks/mlp/mlp.py

Lines changed: 2 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,8 @@ class MLP(keras.Layer):
2020

2121
def __init__(
2222
self,
23+
widths: Sequence[int] = (256, 256),
2324
*,
24-
depth: int = None,
25-
width: int = None,
26-
widths: Sequence[int] = None,
2725
activation: str = "mish",
2826
kernel_initializer: str = "he_normal",
2927
residual: bool = False,
@@ -46,15 +44,8 @@ def __init__(
4644
4745
Parameters
4846
----------
49-
depth : int, optional
50-
Number of layers in the MLP when `widths` is not explicitly provided. Must be
51-
used together with `width`. Default is 2.
52-
width : int, optional
53-
Number of units per layer when `widths` is not explicitly provided. Must be used
54-
together with `depth`. Default is 256.
5547
widths : Sequence[int], optional
56-
Explicitly defines the number of hidden units per layer. If provided, `depth` and
57-
`width` should not be specified. Default is None.
48+
Defines the number of hidden units per layer, as well as the number of layers to be used.
5849
activation : str, optional
5950
Activation function applied in the hidden layers, such as "mish". Default is "mish".
6051
kernel_initializer : str, optional
@@ -76,17 +67,6 @@ def __init__(
7667

7768
super().__init__(**keras_kwargs(kwargs))
7869

79-
if widths is not None:
80-
if depth is not None or width is not None:
81-
raise ValueError("Either specify 'widths' or 'depth' and 'width', not both.")
82-
else:
83-
if depth is None or width is None:
84-
# use the default
85-
depth = 2
86-
width = 256
87-
88-
widths = [width] * depth
89-
9070
self.res_blocks = []
9171
for width in widths:
9272
self.res_blocks.append(

tests/test_networks/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def flow_matching():
2222
from bayesflow.networks import FlowMatching
2323

2424
return FlowMatching(
25-
subnet_kwargs={"widths": None, "width": 64, "depth": 2},
25+
subnet_kwargs={"widths": [64, 64]},
2626
integrate_kwargs={"method": "rk45", "steps": 100},
2727
)
2828

0 commit comments

Comments
 (0)