Skip to content

Commit 0e3383b

Browse files
committed
add documentation
1 parent dc4ee7b commit 0e3383b

File tree

3 files changed

+21
-5
lines changed

3 files changed

+21
-5
lines changed

bayesflow/experimental/diffusion_model/diffusion_model.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,12 @@ def __init__(
8585
Additional keyword arguments passed to the noise schedule constructor. Default is None.
8686
integrate_kwargs : dict[str, any], optional
8787
Configuration dictionary for integration during training or inference. Default is None.
88+
concatenate_subnet_input: bool, optional
89+
Flag for advanced users to control whether all inputs to the subnet should be concatenated
90+
into a single vector or passed as separate arguments. If set to False, the subnet
91+
must accept three separate inputs: 'x' (noisy parameters), 't' (log signal-to-noise ratio),
92+
and optional 'conditions'. Default is True.
93+
8894
**kwargs
8995
Additional keyword arguments passed to the base class and internal components.
9096
"""
@@ -227,7 +233,7 @@ def _apply_subnet(
227233
xtc = tensor_utils.concatenate_valid([xz, log_snr, conditions], axis=-1)
228234
return self.subnet(xtc, training=training)
229235
else:
230-
return self.subnet(xz, log_snr, conditions, training=training)
236+
return self.subnet(x=xz, t=log_snr, conditions=conditions, training=training)
231237

232238
def velocity(
233239
self,

bayesflow/networks/consistency_models/consistency_model.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,11 @@ def __init__(
6767
Final number of discretization steps
6868
subnet_kwargs: dict[str, any], optional
6969
Keyword arguments passed to the subnet constructor or used to update the default MLP settings.
70+
concatenate_subnet_input: bool, optional
71+
Flag for advanced users to control whether all inputs to the subnet should be concatenated
72+
into a single vector or passed as separate arguments. If set to False, the subnet
73+
must accept three separate inputs: 'x' (noisy parameters), 't' (time),
74+
and optional 'conditions'. Default is True.
7075
**kwargs : dict, optional, default: {}
7176
Additional keyword arguments
7277
"""
@@ -268,7 +273,7 @@ def _apply_subnet(
268273
Parameters
269274
----------
270275
x : Tensor
271-
The input tensor for the diffusion model, typically of shape (..., D), but can vary.
276+
The parameter tensor, typically of shape (..., D), but can vary.
272277
t : Tensor
273278
The time tensor, typically of shape (..., 1).
274279
conditions : Tensor, optional
@@ -285,7 +290,7 @@ def _apply_subnet(
285290
xtc = tensor_utils.concatenate_valid([x, t, conditions], axis=-1)
286291
return self.subnet(xtc, training=training)
287292
else:
288-
return self.subnet(x, t, conditions, training=training)
293+
return self.subnet(x=x, t=t, conditions=conditions, training=training)
289294

290295
def consistency_function(self, x: Tensor, t: Tensor, conditions: Tensor = None, training: bool = False) -> Tensor:
291296
"""Compute consistency function.

bayesflow/networks/flow_matching/flow_matching.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,11 @@ def __init__(
9191
Additional keyword arguments for configuring optimal transport. Default is None.
9292
subnet_kwargs: dict[str, any], optional, deprecated
9393
Keyword arguments passed to the subnet constructor or used to update the default MLP settings.
94+
concatenate_subnet_input: bool, optional
95+
Flag for advanced users to control whether all inputs to the subnet should be concatenated
96+
into a single vector or passed as separate arguments. If set to False, the subnet
97+
must accept three separate inputs: 'x' (noisy parameters), 't' (time),
98+
and optional 'conditions'. Default is True.
9499
**kwargs
95100
Additional keyword arguments passed to the subnet and other components.
96101
"""
@@ -165,7 +170,7 @@ def _apply_subnet(
165170
Parameters
166171
----------
167172
x : Tensor
168-
The input tensor for the diffusion model, typically of shape (..., D), but can vary.
173+
The parameter tensor, typically of shape (..., D), but can vary.
169174
t : Tensor
170175
The time tensor, typically of shape (..., 1).
171176
conditions : Tensor, optional
@@ -185,7 +190,7 @@ def _apply_subnet(
185190
else:
186191
if training is False:
187192
t = keras.ops.broadcast_to(t, keras.ops.shape(x)[:-1] + (1,))
188-
return self.subnet(x, t, conditions, training=training)
193+
return self.subnet(x=x, t=t, conditions=conditions, training=training)
189194

190195
def velocity(self, xz: Tensor, time: float | Tensor, conditions: Tensor = None, training: bool = False) -> Tensor:
191196
time = keras.ops.convert_to_tensor(time, dtype=keras.ops.dtype(xz))

0 commit comments

Comments
 (0)