Skip to content

Commit 9d9a73c

Browse files
committed
fix docstring
1 parent 3ea64b7 commit 9d9a73c

File tree

3 files changed

+6
-6
lines changed

3 files changed

+6
-6
lines changed

bayesflow/experimental/diffusion_model/diffusion_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ def _subnet_input(
204204
self, xz: Tensor, log_snr: Tensor, conditions: Tensor = None, training: bool = False
205205
) -> Tensor | tuple[Tensor, Tensor, Tensor]:
206206
"""
207-
Prepares the input for the subnet either by concatenating the latent variable `xz`,
207+
Prepares and passes the input to the subnet either by concatenating the latent variable `xz`,
208208
the signal-to-noise ratio `log_snr`, and optional conditions or by returning them separately.
209209
210210
Parameters
@@ -221,7 +221,7 @@ def _subnet_input(
221221
Returns
222222
-------
223223
Tensor
224-
The concatenated input tensor for the subnet or a tuple of tensors if concatenation is disabled.
224+
The output tensor from the subnet.
225225
"""
226226
if self._concatenate_subnet_input:
227227
xtc = tensor_utils.concatenate_valid([xz, log_snr, conditions], axis=-1)

bayesflow/networks/consistency_models/consistency_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ def _subnet_input(
262262
self, x: Tensor, t: Tensor, conditions: Tensor = None, training: bool = False
263263
) -> Tensor | tuple[Tensor, Tensor, Tensor]:
264264
"""
265-
Prepares the input for the subnet either by concatenating the latent variable `x`,
265+
Prepares and passes the input to the subnet either by concatenating the latent variable `x`,
266266
the time `t`, and optional conditions or by returning them separately.
267267
268268
Parameters
@@ -279,7 +279,7 @@ def _subnet_input(
279279
Returns
280280
-------
281281
Tensor
282-
The concatenated input tensor for the subnet or a tuple of tensors if concatenation is disabled.
282+
The output tensor from the subnet.
283283
"""
284284
if self._concatenate_subnet_input:
285285
xtc = tensor_utils.concatenate_valid([x, t, conditions], axis=-1)

bayesflow/networks/flow_matching/flow_matching.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ def _subnet_input(
159159
self, x: Tensor, t: Tensor, conditions: Tensor = None, training: bool = False
160160
) -> Tensor | tuple[Tensor, Tensor, Tensor]:
161161
"""
162-
Prepares the input for the subnet either by concatenating the latent variable `x`,
162+
Prepares and passes the input to the subnet either by concatenating the latent variable `x`,
163163
the time `t`, and optional conditions or by returning them separately.
164164
165165
Parameters
@@ -176,7 +176,7 @@ def _subnet_input(
176176
Returns
177177
-------
178178
Tensor
179-
The concatenated input tensor for the subnet or a tuple of tensors if concatenation is disabled.
179+
The output tensor from the subnet.
180180
"""
181181
if self._concatenate_subnet_input:
182182
t = keras.ops.broadcast_to(t, keras.ops.shape(x)[:-1] + (1,))

0 commit comments

Comments
 (0)