Skip to content

Commit 2c6d91e

Browse files
committed
fix type info
1 parent eac0371 commit 2c6d91e

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

bayesflow/experimental/diffusion_model/diffusion_model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,9 @@ def convert_prediction_to_x(
199199
return (z + sigma_t**2 * pred) / alpha_t
200200
raise ValueError(f"Unknown prediction type {self._prediction_type}.")
201201

202-
def prepare_subnet_input(self, xz: Tensor, log_snr: Tensor, conditions: Tensor = None) -> Tensor:
202+
def prepare_subnet_input(
203+
self, xz: Tensor, log_snr: Tensor, conditions: Tensor = None
204+
) -> Tensor | tuple[Tensor, Tensor, Tensor]:
203205
"""
204206
Prepares the input for the subnet either by concatenating the latent variable `xz`,
205207
the log signal-to-noise ratio `log_snr`, and optional conditions or by returning them separately.

0 commit comments

Comments
 (0)