Skip to content

Commit ca9e245

Browse files
committed
Fix types for subnets
1 parent c230c8e commit ca9e245

File tree

3 files changed

+11
-11
lines changed

3 files changed

+11
-11
lines changed

bayesflow/experimental/diffusion_model/diffusion_model.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ class DiffusionModel(InferenceNetwork):
5151
def __init__(
5252
self,
5353
*,
54-
subnet: str | type = "mlp",
54+
subnet: str | type | keras.Layer = "mlp",
5555
noise_schedule: Literal["edm", "cosine"] | NoiseSchedule | type = "edm",
5656
prediction_type: Literal["velocity", "noise", "F", "x"] = "F",
5757
loss_type: Literal["velocity", "noise", "F"] = "noise",
@@ -69,9 +69,9 @@ def __init__(
6969
7070
Parameters
7171
----------
72-
subnet : str or type, optional
73-
Architecture for the transformation network. Can be "mlp" or a custom network class.
74-
Default is "mlp".
72+
subnet : str, type or keras.Layer, optional
73+
Architecture for the transformation network. Can be "mlp", a custom network class, or
74+
a Layer object, e.g., `bayesflow.networks.MLP(widths=[32, 32])`. Default is "mlp".
7575
noise_schedule : {'edm', 'cosine'} or NoiseSchedule or type, optional
7676
Noise schedule controlling the diffusion dynamics. Can be a string identifier,
7777
a schedule class, or a pre-initialized schedule instance. Default is "edm".

bayesflow/networks/coupling_flow/coupling_flow.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ class CouplingFlow(InferenceNetwork):
4040

4141
def __init__(
4242
self,
43-
subnet: str | type = "mlp",
43+
subnet: str | type | keras.Layer = "mlp",
4444
depth: int = 6,
4545
transform: str = "affine",
4646
permutation: str | None = "random",
@@ -68,9 +68,9 @@ def __init__(
6868
6969
Parameters
7070
----------
71-
subnet : str or type, optional
72-
The architecture type used within the coupling layers. Can be a string
73-
identifier like "mlp" or a callable type. Default is "mlp".
71+
subnet : str, type, or keras.Layer, optional
72+
Architecture for the transformation network. Can be "mlp", a custom network class, or
73+
a Layer object, e.g., `bayesflow.networks.MLP(widths=[32, 32])`. Default is "mlp".
7474
depth : int, optional
7575
The number of invertible layers in the model. Default is 6.
7676
transform : str, optional

bayesflow/networks/flow_matching/flow_matching.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ class FlowMatching(InferenceNetwork):
5151

5252
def __init__(
5353
self,
54-
subnet: str | keras.Layer = "mlp",
54+
subnet: str | type | keras.Layer = "mlp",
5555
base_distribution: str | Distribution = "normal",
5656
use_optimal_transport: bool = False,
5757
loss_fn: str | keras.Loss = "mse",
@@ -74,8 +74,8 @@ def __init__(
7474
Parameters
7575
----------
7676
subnet : str or keras.Layer, optional
77-
The architecture used for the transformation network. Can be "mlp" or a custom
78-
callable network. Default is "mlp".
77+
Architecture for the transformation network. Can be "mlp", a custom network class, or
78+
a Layer object, e.g., `bayesflow.networks.MLP(widths=[32, 32])`. Default is "mlp".
7979
base_distribution : str, optional
8080
The base probability distribution from which samples are drawn, such as "normal".
8181
Default is "normal".

0 commit comments

Comments
 (0)