Skip to content

Commit 79b4fbb

Browse files
committed
Some syntactic sugar fixes
1 parent 22e15d2 commit 79b4fbb

File tree

5 files changed

+26
-17
lines changed

5 files changed

+26
-17
lines changed

bayesflow/networks/consistency_models/consistency_model.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -45,15 +45,16 @@ def __init__(
4545
eps: float = 0.001,
4646
s0: int | float = 10,
4747
s1: int | float = 50,
48+
subnet_kwargs: dict[str, any] = None,
4849
**kwargs,
4950
):
5051
"""Creates an instance of a consistency model (CM) to be used for standalone consistency training (CT).
5152
5253
Parameters
5354
----------
5455
total_steps : int
55-
The total number of training steps, can be calculate as
56-
number of epochs * number of batches
56+
The total number of training steps, must be calculated as number of epochs * number of batches
57+
and cannot be inferred during construction time.
5758
subnet : str or type, optional, default: "mlp"
5859
A neural network type for the consistency model, will be
5960
instantiated using subnet_kwargs.
@@ -67,21 +68,20 @@ def __init__(
6768
Initial number of discretization steps
6869
s1 : int or float, optional, default: 50
6970
Final number of discretization steps
71+
subnet_kwargs: dict[str, any], optional
72+
Keyword arguments passed to the subnet constructor or used to update the default MLP settings.
7073
**kwargs : dict, optional, default: {}
7174
Additional keyword arguments
7275
"""
7376
super().__init__(base_distribution="normal", **keras_kwargs(kwargs))
7477

7578
self.total_steps = float(total_steps)
7679

80+
subnet_kwargs = subnet_kwargs or {}
7781
if subnet == "mlp":
78-
subnet_kwargs = ConsistencyModel.MLP_DEFAULT_CONFIG.copy()
79-
subnet_kwargs.update(kwargs.get("subnet_kwargs", {}))
80-
else:
81-
subnet_kwargs = kwargs.get("subnet_kwargs", {})
82+
subnet_kwargs = ConsistencyModel.MLP_DEFAULT_CONFIG | subnet_kwargs
8283

8384
self.student = find_network(subnet, **subnet_kwargs)
84-
8585
self.student_projector = keras.layers.Dense(units=None, bias_initializer="zeros", kernel_initializer="zeros")
8686

8787
self.sigma2 = ops.convert_to_tensor(sigma2)
@@ -108,6 +108,7 @@ def __init__(
108108
"eps": eps,
109109
"s0": s0,
110110
"s1": s1,
111+
"subnet_kwargs": subnet_kwargs,
111112
**kwargs,
112113
}
113114
self.config = serialize_value_or_type(self.config, "subnet", subnet)

bayesflow/networks/coupling_flow/coupling_flow.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@ class CouplingFlow(InferenceNetwork):
3636

3737
def __init__(
3838
self,
39-
depth: int = 6,
4039
subnet: str | type = "mlp",
40+
depth: int = 6,
4141
transform: str = "affine",
4242
permutation: str | None = "random",
4343
use_actnorm: bool = True,
@@ -62,11 +62,11 @@ def __init__(
6262
6363
Parameters
6464
----------
65-
depth : int, optional
66-
The number of invertible layers in the model. Default is 6.
6765
subnet : str or type, optional
6866
The architecture type used within the coupling layers. Can be a string
6967
identifier like "mlp" or a callable type. Default is "mlp".
68+
depth : int, optional
69+
The number of invertible layers in the model. Default is 6.
7070
transform : str, optional
7171
The type of transformation used in the coupling layers, such as "affine".
7272
Default is "affine".

bayesflow/networks/coupling_flow/couplings/single_coupling.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,9 @@ class SingleCoupling(InvertibleLayer):
2727
def __init__(self, subnet: str | type = "mlp", transform: str = "affine", **kwargs):
2828
super().__init__(**keras_kwargs(kwargs))
2929

30+
subnet_kwargs = kwargs.get("subnet_kwargs", {})
3031
if subnet == "mlp":
31-
subnet_kwargs = SingleCoupling.MLP_DEFAULT_CONFIG.copy()
32-
subnet_kwargs.update(kwargs.get("subnet_kwargs", {}))
33-
else:
34-
subnet_kwargs = kwargs.get("subnet_kwargs", {})
32+
subnet_kwargs = SingleCoupling.MLP_DEFAULT_CONFIG | subnet_kwargs
3533

3634
self.network = find_network(subnet, **subnet_kwargs)
3735
self.transform = find_transform(transform, **kwargs.get("transform_kwargs", {}))

bayesflow/networks/deep_set/equivariant_module.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,11 +104,19 @@ def call(self, input_set: Tensor, training: bool = False, **kwargs) -> Tensor:
104104
105105
Parameters
106106
----------
107-
#TODO
107+
input_set : Tensor
108+
The input tensor representing a set, with shape
109+
(batch_size, ..., set_size, input_dim).
110+
training : bool, optional
111+
A flag indicating whether the model is in training mode. Default is False.
112+
**kwargs : dict
113+
Additional keyword arguments for compatibility with other functions.
108114
109115
Returns
110116
-------
111-
#TODO
117+
Tensor
118+
The transformed output tensor with the same shape as `input_set`, where
119+
each element is processed through the equivariant transformation.
112120
"""
113121

114122
input_set = self.input_projector(input_set)

bayesflow/networks/flow_matching/flow_matching.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,8 @@ def __init__(
8787
Additional keyword arguments for the integration process. Default is None.
8888
optimal_transport_kwargs : dict[str, any], optional
8989
Additional keyword arguments for configuring optimal transport. Default is None.
90+
subnet_kwargs: dict[str, any], optional
91+
Keyword arguments passed to the subnet constructor or used to update the default MLP settings.
9092
**kwargs
9193
Additional keyword arguments passed to the subnet and other components.
9294
"""
@@ -103,7 +105,6 @@ def __init__(
103105
self.seed_generator = keras.random.SeedGenerator()
104106

105107
subnet_kwargs = subnet_kwargs or {}
106-
107108
if subnet == "mlp":
108109
subnet_kwargs = FlowMatching.MLP_DEFAULT_CONFIG | subnet_kwargs
109110

@@ -116,6 +117,7 @@ def __init__(
116117
"use_optimal_transport": self.use_optimal_transport,
117118
"optimal_transport_kwargs": self.optimal_transport_kwargs,
118119
"integrate_kwargs": self.integrate_kwargs,
120+
"subnet_kwargs": subnet_kwargs,
119121
**kwargs,
120122
}
121123
self.config = serialize_value_or_type(self.config, "subnet", subnet)

0 commit comments

Comments
 (0)