diff --git a/bayesflow/experimental/__init__.py b/bayesflow/experimental/__init__.py
index 4c6f80848..6c0b6828f 100644
--- a/bayesflow/experimental/__init__.py
+++ b/bayesflow/experimental/__init__.py
@@ -4,8 +4,9 @@
 from .cif import CIF
 from .continuous_time_consistency_model import ContinuousTimeConsistencyModel
+from .diffusion_model import DiffusionModel
 from .free_form_flow import FreeFormFlow
 
 from ..utils._docs import _add_imports_to_all
 
-_add_imports_to_all(include_modules=[])
+_add_imports_to_all(include_modules=["diffusion_model"])
diff --git a/bayesflow/experimental/diffusion_model/__init__.py b/bayesflow/experimental/diffusion_model/__init__.py
new file mode 100644
index 000000000..8c5bd247f
--- /dev/null
+++ b/bayesflow/experimental/diffusion_model/__init__.py
@@ -0,0 +1,9 @@
+from .diffusion_model import DiffusionModel
+from .noise_schedule import NoiseSchedule
+from .cosine_noise_schedule import CosineNoiseSchedule
+from .edm_noise_schedule import EDMNoiseSchedule
+from .dispatch import find_noise_schedule
+
+from ...utils._docs import _add_imports_to_all
+
+_add_imports_to_all(include_modules=[])
diff --git a/bayesflow/experimental/diffusion_model/cosine_noise_schedule.py b/bayesflow/experimental/diffusion_model/cosine_noise_schedule.py
new file mode 100644
index 000000000..6aab71646
--- /dev/null
+++ b/bayesflow/experimental/diffusion_model/cosine_noise_schedule.py
@@ -0,0 +1,85 @@
+import math
+from typing import Union, Literal
+
+from keras import ops
+
+from bayesflow.types import Tensor
+from bayesflow.utils.serialization import deserialize, serializable
+
+from .noise_schedule import NoiseSchedule
+
+
+# disable module check, use potential module after moving from experimental
+@serializable("bayesflow.networks", disable_module_check=True)
+class CosineNoiseSchedule(NoiseSchedule):
+    """Cosine noise schedule for diffusion models. This schedule is based on the cosine schedule from [1].
+
+    [1] Diffusion Models Beat GANs on Image Synthesis: Dhariwal and Nichol (2021)
+    """
+
+    def __init__(
+        self,
+        min_log_snr: float = -15,
+        max_log_snr: float = 15,
+        shift: float = 0.0,
+        weighting: Literal["sigmoid", "likelihood_weighting"] = "sigmoid",
+    ):
+        """
+        Initialize the cosine noise schedule.
+
+        Parameters
+        ----------
+        min_log_snr : float, optional
+            The minimum log signal-to-noise ratio (lambda). Default is -15.
+        max_log_snr : float, optional
+            The maximum log signal-to-noise ratio (lambda). Default is 15.
+        shift : float, optional
+            Shift the log signal-to-noise ratio (lambda) by this amount. Default is 0.0.
+            For images, use shift = log(base_resolution / d), where d is the used resolution of the image.
+        weighting : Literal["sigmoid", "likelihood_weighting"], optional
+            The type of weighting function to use for the noise schedule. Default is "sigmoid".
+ """ + super().__init__(name="cosine_noise_schedule", variance_type="preserving", weighting=weighting) + self._shift = shift + self._weighting = weighting + self.log_snr_min = min_log_snr + self.log_snr_max = max_log_snr + + self._t_min = self.get_t_from_log_snr(log_snr_t=self.log_snr_max, training=True) + self._t_max = self.get_t_from_log_snr(log_snr_t=self.log_snr_min, training=True) + + def _truncated_t(self, t: Tensor) -> Tensor: + return self._t_min + (self._t_max - self._t_min) * t + + def get_log_snr(self, t: Union[float, Tensor], training: bool) -> Tensor: + """Get the log signal-to-noise ratio (lambda) for a given diffusion time.""" + t_trunc = self._truncated_t(t) + return -2 * ops.log(ops.tan(math.pi * t_trunc * 0.5)) + 2 * self._shift + + def get_t_from_log_snr(self, log_snr_t: Union[Tensor, float], training: bool) -> Tensor: + """Get the diffusion time (t) from the log signal-to-noise ratio (lambda).""" + # SNR = -2 * log(tan(pi*t/2)) => t = 2/pi * arctan(exp(-snr/2)) + return 2 / math.pi * ops.arctan(ops.exp((2 * self._shift - log_snr_t) * 0.5)) + + def derivative_log_snr(self, log_snr_t: Tensor, training: bool) -> Tensor: + """Compute d/dt log(1 + e^(-snr(t))), which is used for the reverse SDE.""" + t = self.get_t_from_log_snr(log_snr_t=log_snr_t, training=training) + + # Compute the truncated time t_trunc + t_trunc = self._truncated_t(t) + dsnr_dx = -(2 * math.pi) / ops.sin(math.pi * t_trunc) + + # Using the chain rule on f(t) = log(1 + e^(-snr(t))): + # f'(t) = - (e^{-snr(t)} / (1 + e^{-snr(t)})) * dsnr_dt + dsnr_dt = dsnr_dx * (self._t_max - self._t_min) + factor = ops.exp(-log_snr_t) / (1 + ops.exp(-log_snr_t)) + return -factor * dsnr_dt + + def get_config(self): + return dict( + min_log_snr=self.log_snr_min, max_log_snr=self.log_snr_max, shift=self._shift, weighting=self._weighting + ) + + @classmethod + def from_config(cls, config, custom_objects=None): + return cls(**deserialize(config, custom_objects=custom_objects)) diff --git a/bayesflow/experimental/diffusion_model/diffusion_model.py b/bayesflow/experimental/diffusion_model/diffusion_model.py new file mode 100644 index 000000000..c5e1154c6 --- /dev/null +++ b/bayesflow/experimental/diffusion_model/diffusion_model.py @@ -0,0 +1,437 @@ +from collections.abc import Sequence +from typing import Literal + +import keras +from keras import ops + +from bayesflow.networks import InferenceNetwork +from bayesflow.types import Tensor, Shape +from bayesflow.utils import ( + expand_right_as, + find_network, + jacobian_trace, + layer_kwargs, + weighted_mean, + integrate, + integrate_stochastic, + logging, + tensor_utils, +) +from .dispatch import find_noise_schedule +from bayesflow.utils.serialization import serialize, deserialize, serializable + + +# disable module check, use potential module after moving from experimental +@serializable("bayesflow.networks", disable_module_check=True) +class DiffusionModel(InferenceNetwork): + """Diffusion Model as described in this overview paper [1]. + + [1] Variational Diffusion Models 2.0: Understanding Diffusion Model Objectives as the ELBO with Simple Data + Augmentation: Kingma et al. (2023) + + [2] Score-Based Generative Modeling through Stochastic Differential Equations: Song et al. 
(2021)
+    """
+
+    MLP_DEFAULT_CONFIG = {
+        "widths": (256, 256, 256, 256, 256),
+        "activation": "mish",
+        "kernel_initializer": "he_normal",
+        "residual": True,
+        "dropout": 0.0,
+        "spectral_normalization": False,
+    }
+
+    INTEGRATE_DEFAULT_CONFIG = {
+        "method": "euler",  # or euler_maruyama
+        "steps": 250,
+    }
+
+    def __init__(
+        self,
+        *,
+        subnet: str | type = "mlp",
+        integrate_kwargs: dict[str, any] = None,
+        noise_schedule: Literal["edm", "cosine"] | dict | type = "edm",
+        prediction_type: Literal["velocity", "noise", "F"] = "F",
+        **kwargs,
+    ):
+        """
+        Initializes a diffusion model with a configurable subnet architecture.
+
+        This model learns a transformation from a Gaussian latent distribution to a target distribution using a
+        specified subnet type, which can be an MLP or a custom network.
+
+        The integration can be customized via the integrate_kwargs configuration dictionary. Different noise
+        schedules and prediction types are available.
+
+        Parameters
+        ----------
+        subnet : str or type, optional
+            The architecture used for the transformation network. Can be "mlp" or a custom
+            callable network. Default is "mlp".
+        integrate_kwargs : dict[str, any], optional
+            Additional keyword arguments for the integration process. Default is None.
+        noise_schedule : Literal['edm', 'cosine'], dict or type, optional
+            The noise schedule used for the diffusion process. Can be "cosine", "edm", or a custom noise schedule.
+            You can also pass a dictionary with the configuration for the noise schedule, e.g.,
+            {"name": "cosine", "shift": 1.0}. Default is "edm".
+        prediction_type : Literal['velocity', 'noise', 'F'], optional
+            The type of prediction used in the diffusion model. Can be "velocity", "noise", or "F" (EDM).
+            Default is "F".
+        **kwargs
+            Additional keyword arguments passed to the subnet and other components.
+        """
+        super().__init__(base_distribution="normal", **kwargs)
+
+        self.noise_schedule = find_noise_schedule(noise_schedule)
+        self.noise_schedule.validate()
+
+        if prediction_type not in ["noise", "velocity", "F"]:  # F is EDM
+            raise ValueError(f"Unknown prediction type: {prediction_type}")
+        self._prediction_type = prediction_type
+        self._loss_type = kwargs.get("loss_type", "noise")
+        if self._loss_type not in ["noise", "velocity", "F"]:
+            raise ValueError(f"Unknown loss type: {self._loss_type}")
+        if self._loss_type != "noise":
+            logging.warning(
+                "The standard noise schedules define their weighting functions for the noise-prediction loss. "
+                "Consider replacing the weighting if you use a different loss type."
+ ) + + # clipping of prediction (after it was transformed to x-prediction) + # keeping this private for now, as it is usually not required in SBI and somewhat dangerous + self._clip_x = kwargs.get("clip_x", None) + if self._clip_x is not None: + if len(self._clip_x) != 2 or self._clip_x[0] > self._clip_x[1]: + raise ValueError("'clip_x' has to be a list or tuple with the values [x_min, x_max]") + + self.integrate_kwargs = self.INTEGRATE_DEFAULT_CONFIG | (integrate_kwargs or {}) + self.seed_generator = keras.random.SeedGenerator() + + if subnet == "mlp": + self.subnet = find_network(subnet, **self.MLP_DEFAULT_CONFIG) + else: + self.subnet = find_network(subnet) + self.output_projector = keras.layers.Dense(units=None, bias_initializer="zeros") + + def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None: + if self.built: + return + + self.base_distribution.build(xz_shape) + + self.output_projector.units = xz_shape[-1] + input_shape = list(xz_shape) + + # construct time vector + input_shape[-1] += 1 + if conditions_shape is not None: + input_shape[-1] += conditions_shape[-1] + + input_shape = tuple(input_shape) + + self.subnet.build(input_shape) + out_shape = self.subnet.compute_output_shape(input_shape) + self.output_projector.build(out_shape) + + def get_config(self): + base_config = super().get_config() + base_config = layer_kwargs(base_config) + + config = { + "subnet": self.subnet, + "noise_schedule": self.noise_schedule, + "integrate_kwargs": self.integrate_kwargs, + "prediction_type": self._prediction_type, + "loss_type": self._loss_type, + } + return base_config | serialize(config) + + @classmethod + def from_config(cls, config, custom_objects=None): + return cls(**deserialize(config, custom_objects=custom_objects)) + + def convert_prediction_to_x( + self, pred: Tensor, z: Tensor, alpha_t: Tensor, sigma_t: Tensor, log_snr_t: Tensor + ) -> Tensor: + """Convert the prediction of the neural network to the x space.""" + if self._prediction_type == "velocity": + # convert v into x + x = alpha_t * z - sigma_t * pred + elif self._prediction_type == "noise": + # convert noise prediction into x + x = (z - sigma_t * pred) / alpha_t + elif self._prediction_type == "F": # EDM + sigma_data = self.noise_schedule.sigma_data if hasattr(self.noise_schedule, "sigma_data") else 1.0 + x1 = (sigma_data**2 * alpha_t) / (ops.exp(-log_snr_t) + sigma_data**2) + x2 = ops.exp(-log_snr_t / 2) * sigma_data / ops.sqrt(ops.exp(-log_snr_t) + sigma_data**2) + x = x1 * z + x2 * pred + elif self._prediction_type == "x": + x = pred + else: # "score" + x = (z + sigma_t**2 * pred) / alpha_t + + if self._clip_x is not None: + x = ops.clip(x, self._clip_x[0], self._clip_x[1]) + return x + + def velocity( + self, + xz: Tensor, + time: float | Tensor, + stochastic_solver: bool, + conditions: Tensor = None, + training: bool = False, + ) -> Tensor: + # calculate the current noise level and transform into correct shape + log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz) + log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,)) + alpha_t, sigma_t = self.noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t, training=training) + + if conditions is None: + xtc = tensor_utils.concatenate_valid([xz, self._transform_log_snr(log_snr_t)], axis=-1) + else: + xtc = tensor_utils.concatenate_valid([xz, self._transform_log_snr(log_snr_t), conditions], axis=-1) + pred = self.output_projector(self.subnet(xtc, training=training), training=training) + + x_pred = 
self.convert_prediction_to_x(pred=pred, z=xz, alpha_t=alpha_t, sigma_t=sigma_t, log_snr_t=log_snr_t) + # convert x to score + score = (alpha_t * x_pred - xz) / ops.square(sigma_t) + + # compute velocity f, g of the SDE or ODE + f, g_squared = self.noise_schedule.get_drift_diffusion(log_snr_t=log_snr_t, x=xz) + + if stochastic_solver: + # for the SDE: d(z) = [f(z, t) - g(t) ^ 2 * score(z, lambda )] dt + g(t) dW + out = f - g_squared * score + else: + # for the ODE: d(z) = [f(z, t) - 0.5 * g(t) ^ 2 * score(z, lambda )] dt + out = f - 0.5 * g_squared * score + + return out + + def compute_diffusion_term( + self, + xz: Tensor, + time: float | Tensor, + training: bool = False, + ) -> Tensor: + # calculate the current noise level and transform into correct shape + log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz) + log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,)) + g_squared = self.noise_schedule.get_drift_diffusion(log_snr_t=log_snr_t) + return ops.sqrt(g_squared) + + def _velocity_trace( + self, + xz: Tensor, + time: Tensor, + conditions: Tensor = None, + max_steps: int = None, + training: bool = False, + ) -> (Tensor, Tensor): + def f(x): + return self.velocity(x, time=time, stochastic_solver=False, conditions=conditions, training=training) + + v, trace = jacobian_trace(f, xz, max_steps=max_steps, seed=self.seed_generator, return_output=True) + + return v, ops.expand_dims(trace, axis=-1) + + def _transform_log_snr(self, log_snr: Tensor) -> Tensor: + """Transform the log_snr to the range [-1, 1] for the diffusion process.""" + log_snr_min = self.noise_schedule.log_snr_min + log_snr_max = self.noise_schedule.log_snr_max + + # Calculate normalized value within the range [0, 1] + normalized_snr = (log_snr - log_snr_min) / (log_snr_max - log_snr_min) + + # Scale to [-1, 1] range + scaled_value = 2 * normalized_snr - 1 + return scaled_value + + def _forward( + self, + x: Tensor, + conditions: Tensor = None, + density: bool = False, + training: bool = False, + **kwargs, + ) -> Tensor | tuple[Tensor, Tensor]: + integrate_kwargs = {"start_time": 0.0, "stop_time": 1.0} + integrate_kwargs = integrate_kwargs | self.integrate_kwargs + integrate_kwargs = integrate_kwargs | kwargs + if integrate_kwargs["method"] == "euler_maruyama": + raise ValueError("Stochastic methods are not supported for forward integration.") + + if density: + + def deltas(time, xz): + v, trace = self._velocity_trace(xz, time=time, conditions=conditions, training=training) + return {"xz": v, "trace": trace} + + state = { + "xz": x, + "trace": ops.zeros(ops.shape(x)[:-1] + (1,), dtype=ops.dtype(x)), + } + state = integrate( + deltas, + state, + **integrate_kwargs, + ) + + z = state["xz"] + log_density = self.base_distribution.log_prob(z) + ops.squeeze(state["trace"], axis=-1) + + return z, log_density + + def deltas(time, xz): + return { + "xz": self.velocity(xz, time=time, stochastic_solver=False, conditions=conditions, training=training) + } + + state = {"xz": x} + state = integrate( + deltas, + state, + **integrate_kwargs, + ) + z = state["xz"] + return z + + def _inverse( + self, + z: Tensor, + conditions: Tensor = None, + density: bool = False, + training: bool = False, + **kwargs, + ) -> Tensor | tuple[Tensor, Tensor]: + integrate_kwargs = {"start_time": 1.0, "stop_time": 0.0} + integrate_kwargs = integrate_kwargs | self.integrate_kwargs + integrate_kwargs = integrate_kwargs | kwargs + if density: + if integrate_kwargs["method"] == "euler_maruyama": + raise 
ValueError("Stochastic methods are not supported for density computation.") + + def deltas(time, xz): + v, trace = self._velocity_trace(xz, time=time, conditions=conditions, training=training) + return {"xz": v, "trace": trace} + + state = { + "xz": z, + "trace": ops.zeros(ops.shape(z)[:-1] + (1,), dtype=ops.dtype(z)), + } + state = integrate(deltas, state, **integrate_kwargs) + + x = state["xz"] + log_density = self.base_distribution.log_prob(z) - ops.squeeze(state["trace"], axis=-1) + + return x, log_density + + state = {"xz": z} + if integrate_kwargs["method"] == "euler_maruyama": + + def deltas(time, xz): + return { + "xz": self.velocity(xz, time=time, stochastic_solver=True, conditions=conditions, training=training) + } + + def diffusion(time, xz): + return {"xz": self.compute_diffusion_term(xz, time=time, training=training)} + + state = integrate_stochastic( + drift_fn=deltas, + diffusion_fn=diffusion, + state=state, + seed=self.seed_generator, + **integrate_kwargs, + ) + else: + + def deltas(time, xz): + return { + "xz": self.velocity( + xz, time=time, stochastic_solver=False, conditions=conditions, training=training + ) + } + + state = integrate( + deltas, + state, + **integrate_kwargs, + ) + + x = state["xz"] + return x + + def compute_metrics( + self, + x: Tensor | Sequence[Tensor, ...], + conditions: Tensor = None, + sample_weight: Tensor = None, + stage: str = "training", + ) -> dict[str, Tensor]: + training = stage == "training" + # use same noise schedule for training and validation to keep them comparable + noise_schedule_training_stage = stage == "training" or stage == "validation" + if not self.built: + xz_shape = ops.shape(x) + conditions_shape = None if conditions is None else ops.shape(conditions) + self.build(xz_shape, conditions_shape) + + # sample training diffusion time as low discrepancy sequence to decrease variance + u0 = keras.random.uniform(shape=(1,), dtype=ops.dtype(x), seed=self.seed_generator) + i = ops.arange(0, ops.shape(x)[0], dtype=ops.dtype(x)) # tensor of indices + t = (u0 + i / ops.cast(ops.shape(x)[0], dtype=ops.dtype(x))) % 1 + + # calculate the noise level + log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t, training=noise_schedule_training_stage), x) + alpha_t, sigma_t = self.noise_schedule.get_alpha_sigma( + log_snr_t=log_snr_t, training=noise_schedule_training_stage + ) + weights_for_snr = self.noise_schedule.get_weights_for_snr(log_snr_t=log_snr_t) + + # generate noise vector + eps_t = keras.random.normal(ops.shape(x), dtype=ops.dtype(x), seed=self.seed_generator) + + # diffuse x + diffused_x = alpha_t * x + sigma_t * eps_t + + # calculate output of the network + if conditions is None: + xtc = tensor_utils.concatenate_valid([diffused_x, self._transform_log_snr(log_snr_t)], axis=-1) + else: + xtc = tensor_utils.concatenate_valid([diffused_x, self._transform_log_snr(log_snr_t), conditions], axis=-1) + pred = self.output_projector(self.subnet(xtc, training=training), training=training) + + x_pred = self.convert_prediction_to_x( + pred=pred, z=diffused_x, alpha_t=alpha_t, sigma_t=sigma_t, log_snr_t=log_snr_t + ) + + # Calculate loss + if self._loss_type == "noise": + # convert x to epsilon prediction + noise_pred = (diffused_x - alpha_t * x_pred) / sigma_t + loss = weights_for_snr * ops.mean((noise_pred - eps_t) ** 2, axis=-1) + elif self._loss_type == "velocity": + # convert x to velocity prediction + velocity_pred = (alpha_t * diffused_x - x_pred) / sigma_t + v_t = alpha_t * eps_t - sigma_t * x + loss = weights_for_snr * 
ops.mean((velocity_pred - v_t) ** 2, axis=-1) + elif self._loss_type == "F": + # convert x to F prediction + sigma_data = self.noise_schedule.sigma_data if hasattr(self.noise_schedule, "sigma_data") else 1.0 + x1 = ops.sqrt(ops.exp(-log_snr_t) + sigma_data**2) / (ops.exp(-log_snr_t / 2) * sigma_data) + x2 = (sigma_data * alpha_t) / (ops.exp(-log_snr_t / 2) * ops.sqrt(ops.exp(-log_snr_t) + sigma_data**2)) + f_pred = x1 * x_pred - x2 * diffused_x + f_t = x1 * x - x2 * diffused_x + loss = weights_for_snr * ops.mean((f_pred - f_t) ** 2, axis=-1) + else: + raise ValueError(f"Unknown loss type: {self._loss_type}") + + # apply sample weight + loss = weighted_mean(loss, sample_weight) + + base_metrics = super().compute_metrics(x, conditions=conditions, sample_weight=sample_weight, stage=stage) + return base_metrics | {"loss": loss} diff --git a/bayesflow/experimental/diffusion_model/dispatch.py b/bayesflow/experimental/diffusion_model/dispatch.py new file mode 100644 index 000000000..bc02ab3f3 --- /dev/null +++ b/bayesflow/experimental/diffusion_model/dispatch.py @@ -0,0 +1,51 @@ +from functools import singledispatch +from .noise_schedule import NoiseSchedule + + +@singledispatch +def find_noise_schedule(arg, *args, **kwargs): + raise TypeError(f"Not a noise schedule: {arg!r}. Please pass an object of type 'NoiseSchedule'.") + + +@find_noise_schedule.register +def _(noise_schedule: NoiseSchedule): + return noise_schedule + + +@find_noise_schedule.register +def _(name: str, *args, **kwargs): + match name.lower(): + case "cosine": + from .cosine_noise_schedule import CosineNoiseSchedule + + return CosineNoiseSchedule() + case "edm": + from .edm_noise_schedule import EDMNoiseSchedule + + return EDMNoiseSchedule() + case other: + raise ValueError(f"Unsupported noise schedule name: '{other}'.") + + +@find_noise_schedule.register +def _(config: dict, *args, **kwargs): + name = config.get("name", "").lower() + params = {k: v for k, v in config.items() if k != "name"} + match name: + case "cosine": + from .cosine_noise_schedule import CosineNoiseSchedule + + return CosineNoiseSchedule(**params) + case "edm": + from .edm_noise_schedule import EDMNoiseSchedule + + return EDMNoiseSchedule(**params) + case other: + raise ValueError(f"Unsupported noise schedule config: '{other}'.") + + +@find_noise_schedule.register +def _(cls: type, *args, **kwargs): + if issubclass(cls, NoiseSchedule): + return cls(*args, **kwargs) + raise TypeError(f"Expected subclass of NoiseSchedule, got {cls}") diff --git a/bayesflow/experimental/diffusion_model/edm_noise_schedule.py b/bayesflow/experimental/diffusion_model/edm_noise_schedule.py new file mode 100644 index 000000000..a7973fbd0 --- /dev/null +++ b/bayesflow/experimental/diffusion_model/edm_noise_schedule.py @@ -0,0 +1,112 @@ +import math +from typing import Union + +from keras import ops + +from bayesflow.types import Tensor +from bayesflow.utils.serialization import deserialize, serializable + +from .noise_schedule import NoiseSchedule + + +# disable module check, use potential module after moving from experimental +@serializable("bayesflow.networks", disable_module_check=True) +class EDMNoiseSchedule(NoiseSchedule): + """EDM noise schedule for diffusion models. This schedule is based on the EDM paper [1]. + This should be used with the F-prediction type in the diffusion model. + + [1] Elucidating the Design Space of Diffusion-Based Generative Models: Karras et al. 
(2022)
+    """
+
+    def __init__(self, sigma_data: float = 1.0, sigma_min: float = 1e-4, sigma_max: float = 80.0):
+        """
+        Initialize the EDM noise schedule.
+
+        Parameters
+        ----------
+        sigma_data : float, optional
+            The standard deviation of the data distribution. The network input and the weighting function
+            are both scaled by this factor. Default is 1.0.
+        sigma_min : float, optional
+            The minimum noise level. Only relevant for sampling. Default is 1e-4.
+        sigma_max : float, optional
+            The maximum noise level. Only relevant for sampling. Default is 80.0.
+        """
+        super().__init__(name="edm_noise_schedule", variance_type="preserving")
+        self.sigma_data = sigma_data
+        # training settings
+        self.p_mean = -1.2
+        self.p_std = 1.2
+        # sampling settings
+        self.sigma_max = sigma_max
+        self.sigma_min = sigma_min
+        self.rho = 7
+
+        # convert EDM parameters to signal-to-noise ratio formulation
+        self.log_snr_min = -2 * ops.log(sigma_max)
+        self.log_snr_max = -2 * ops.log(sigma_min)
+        # t is not truncated for EDM by definition of the sampling schedule
+        # training bounds should be set to avoid numerical issues
+        self._log_snr_min_training = self.log_snr_min - 1  # t=1 is never sampled during training
+        self._log_snr_max_training = self.log_snr_max + 1  # t=0 is almost surely never sampled during training
+
+    def get_log_snr(self, t: Union[float, Tensor], training: bool) -> Tensor:
+        """Get the log signal-to-noise ratio (lambda) for a given diffusion time."""
+        if training:
+            # SNR = -dist.icdf(t_trunc); the sign appears to be flipped in Kingma et al. (2023)
+            loc = -2 * self.p_mean
+            scale = 2 * self.p_std
+            snr = loc + scale * ops.erfinv(2 * t - 1) * math.sqrt(2)
+            snr = ops.clip(snr, x_min=self._log_snr_min_training, x_max=self._log_snr_max_training)
+        else:  # sampling
+            sigma_min_rho = self.sigma_min ** (1 / self.rho)
+            sigma_max_rho = self.sigma_max ** (1 / self.rho)
+            snr = -2 * self.rho * ops.log(sigma_max_rho + (1 - t) * (sigma_min_rho - sigma_max_rho))
+        return snr
+
+    def get_t_from_log_snr(self, log_snr_t: Union[float, Tensor], training: bool) -> Tensor:
+        """Get the diffusion time (t) from the log signal-to-noise ratio (lambda)."""
+        if training:
+            # SNR = -dist.icdf(t_trunc) => t = dist.cdf(-snr); see the sign note in get_log_snr
+            loc = -2 * self.p_mean
+            scale = 2 * self.p_std
+            x = log_snr_t
+            t = 0.5 * (1 + ops.erf((x - loc) / (scale * math.sqrt(2.0))))
+        else:  # sampling
+            # SNR = -2 * rho * log(sigma_max ** (1/rho) + (1 - t) * (sigma_min ** (1/rho) - sigma_max ** (1/rho)))
+            # => t = 1 - ((exp(-snr/(2*rho)) - sigma_max ** (1/rho)) / (sigma_min ** (1/rho) - sigma_max ** (1/rho)))
+            sigma_min_rho = self.sigma_min ** (1 / self.rho)
+            sigma_max_rho = self.sigma_max ** (1 / self.rho)
+            t = 1 - ((ops.exp(-log_snr_t / (2 * self.rho)) - sigma_max_rho) / (sigma_min_rho - sigma_max_rho))
+        return t
+
+    def derivative_log_snr(self, log_snr_t: Tensor, training: bool) -> Tensor:
+        """Compute d/dt log(1 + e^(-snr(t))), which is used for the reverse SDE."""
+        if training:
+            raise NotImplementedError("Derivative of log SNR is not implemented for training mode.")
+        # sampling mode
+        t = self.get_t_from_log_snr(log_snr_t=log_snr_t, training=training)
+
+        # SNR = -2*rho*log(s_max + (1 - x)*(s_min - s_max))
+        s_max = self.sigma_max ** (1 / self.rho)
+        s_min = self.sigma_min ** (1 / self.rho)
+        u = s_max + (1 - t) * (s_min - s_max)
+        # d/dx snr = 2*rho*(s_min - s_max) / u
+        dsnr_dx = 2 * self.rho * (s_min - s_max) / u
+
+        # Using the chain rule on f(t) = log(1 +
e^(-snr(t))): + # f'(t) = - (e^{-snr(t)} / (1 + e^{-snr(t)})) * dsnr_dt + factor = ops.exp(-log_snr_t) / (1 + ops.exp(-log_snr_t)) + return -factor * dsnr_dx + + def get_weights_for_snr(self, log_snr_t: Tensor) -> Tensor: + """Get weights for the signal-to-noise ratio (snr) for a given log signal-to-noise ratio (lambda).""" + # for F-prediction: w = (ops.exp(-log_snr_t) + sigma_data^2) / (ops.exp(-log_snr_t)*sigma_data^2) + return ops.exp(-log_snr_t) / ops.square(self.sigma_data) + 1 + + def get_config(self): + return dict(sigma_data=self.sigma_data, sigma_min=self.sigma_min, sigma_max=self.sigma_max) + + @classmethod + def from_config(cls, config, custom_objects=None): + return cls(**deserialize(config, custom_objects=custom_objects)) diff --git a/bayesflow/experimental/diffusion_model/noise_schedule.py b/bayesflow/experimental/diffusion_model/noise_schedule.py new file mode 100644 index 000000000..21ffc1ef8 --- /dev/null +++ b/bayesflow/experimental/diffusion_model/noise_schedule.py @@ -0,0 +1,163 @@ +from abc import ABC, abstractmethod +from typing import Union, Literal + +from keras import ops + +from bayesflow.types import Tensor +from bayesflow.utils.serialization import deserialize, serializable + + +# disable module check, use potential module after moving from experimental +@serializable("bayesflow.networks", disable_module_check=True) +class NoiseSchedule(ABC): + r"""Noise schedule for diffusion models. We follow the notation from [1]. + + The diffusion process is defined by a noise schedule, which determines how the noise level changes over time. + We define the noise schedule as a function of the log signal-to-noise ratio (lambda), which can be + interchangeably used with the diffusion time (t). + + The noise process is defined as: z = alpha(t) * x + sigma(t) * e, where e ~ N(0, I). + The schedule is defined as: \lambda(t) = \log \sigma^2(t) - \log \alpha^2(t). + + We can also define a weighting function for each noise level for the loss function. Often the noise schedule is + the same for the forward and reverse process, but this is not necessary and can be changed via the training flag. + + [1] Variational Diffusion Models 2.0: Understanding Diffusion Model Objectives as the ELBO with Simple Data + Augmentation: Kingma et al. (2023) + """ + + def __init__( + self, + name: str, + variance_type: Literal["preserving", "exploding"], + weighting: Literal["sigmoid", "likelihood_weighting"] = None, + ): + """ + Initialize the noise schedule. + + Parameters + ---------- + name : str + The name of the noise schedule. + variance_type : Literal["preserving", "exploding"] + If the variance of noise added to the data should be preserved over time, use "preserving". + If the variance of noise added to the data should increase over time, use "exploding". + Default is "preserving". + weighting : Literal["sigmoid", "likelihood_weighting"], optional + The type of weighting function to use for the noise schedule. + Default is None, which means no weighting is applied. 
+ """ + self.name = name + self._variance_type = variance_type + self.log_snr_min = None # should be set in the subclasses + self.log_snr_max = None # should be set in the subclasses + self._weighting = weighting + + @abstractmethod + def get_log_snr(self, t: Union[float, Tensor], training: bool) -> Tensor: + """Get the log signal-to-noise ratio (lambda) for a given diffusion time.""" + pass + + @abstractmethod + def get_t_from_log_snr(self, log_snr_t: Union[float, Tensor], training: bool) -> Tensor: + """Get the diffusion time (t) from the log signal-to-noise ratio (lambda).""" + pass + + @abstractmethod + def derivative_log_snr(self, log_snr_t: Union[float, Tensor], training: bool) -> Tensor: + r"""Compute \beta(t) = d/dt log(1 + e^(-snr(t))). This is usually used for the reverse SDE.""" + pass + + def get_drift_diffusion(self, log_snr_t: Tensor, x: Tensor = None, training: bool = False) -> tuple[Tensor, Tensor]: + r"""Compute the drift and optionally the squared diffusion term for the reverse SDE. + It can be derived from the derivative of the schedule: + + math:: + \beta(t) = d/dt \log(1 + e^{-snr(t)}) + + f(z, t) = -0.5 * \beta(t) * z + + g(t)^2 = \beta(t) + + The corresponding differential equations are:: + + SDE: d(z) = [ f(z, t) - g(t)^2 * score(z, lambda) ] dt + g(t) dW + ODE: dz = [ f(z, t) - 0.5 * g(t)^2 * score(z, lambda) ] dt + + For a variance exploding schedule, one should set f(z, t) = 0. + """ + beta = self.derivative_log_snr(log_snr_t=log_snr_t, training=training) + if x is None: # return g^2 only + return beta + if self._variance_type == "preserving": + f = -0.5 * beta * x + elif self._variance_type == "exploding": + f = ops.zeros_like(beta) + else: + raise ValueError(f"Unknown variance type: {self._variance_type}") + return f, beta + + def get_alpha_sigma(self, log_snr_t: Tensor, training: bool) -> tuple[Tensor, Tensor]: + """Get alpha and sigma for a given log signal-to-noise ratio (lambda). + + Default is a variance preserving schedule:: + + alpha(t) = sqrt(sigmoid(log_snr_t)) + sigma(t) = sqrt(sigmoid(-log_snr_t)) + + For a variance exploding schedule, one should set alpha^2 = 1 and sigma^2 = exp(-lambda) + """ + if self._variance_type == "preserving": + # variance preserving schedule + alpha_t = ops.sqrt(ops.sigmoid(log_snr_t)) + sigma_t = ops.sqrt(ops.sigmoid(-log_snr_t)) + elif self._variance_type == "exploding": + # variance exploding schedule + alpha_t = ops.ones_like(log_snr_t) + sigma_t = ops.sqrt(ops.exp(-log_snr_t)) + else: + raise TypeError(f"Unknown variance type: {self._variance_type}") + return alpha_t, sigma_t + + def get_weights_for_snr(self, log_snr_t: Tensor) -> Tensor: + """Get weights for the signal-to-noise ratio (snr) for a given log signal-to-noise ratio (lambda). + Default weighting is None, which means only ones are returned. + Generally, weighting functions should be defined for a noise prediction loss. + """ + if self._weighting is None: + return ops.ones_like(log_snr_t) + elif self._weighting == "sigmoid": + # sigmoid weighting based on Kingma et al. (2023) + return ops.sigmoid(-log_snr_t + 2) + elif self._weighting == "likelihood_weighting": + # likelihood weighting based on Song et al. 
(2021)
+            g_squared = self.get_drift_diffusion(log_snr_t=log_snr_t)
+            sigma_t = self.get_alpha_sigma(log_snr_t=log_snr_t, training=True)[1]
+            return g_squared / ops.square(sigma_t)
+        else:
+            raise ValueError(f"Unknown weighting type: {self._weighting}")
+
+    def get_config(self):
+        return dict(name=self.name, variance_type=self._variance_type, weighting=self._weighting)
+
+    @classmethod
+    def from_config(cls, config, custom_objects=None):
+        return cls(**deserialize(config, custom_objects=custom_objects))
+
+    def validate(self):
+        """Validate the noise schedule."""
+        if self.log_snr_min >= self.log_snr_max:
+            raise ValueError("min_log_snr must be less than max_log_snr.")
+        for training in [True, False]:
+            if not ops.isfinite(self.get_log_snr(0.0, training=training)):
+                raise ValueError(f"log_snr(0) must be finite with training={training}.")
+            if not ops.isfinite(self.get_log_snr(1.0, training=training)):
+                raise ValueError(f"log_snr(1) must be finite with training={training}.")
+            if not ops.isfinite(self.get_t_from_log_snr(self.log_snr_max, training=training)):
+                raise ValueError(f"t(0) must be finite with training={training}.")
+            if not ops.isfinite(self.get_t_from_log_snr(self.log_snr_min, training=training)):
+                raise ValueError(f"t(1) must be finite with training={training}.")
+        if not ops.isfinite(self.derivative_log_snr(self.log_snr_max, training=False)):
+            raise ValueError("d/dt log_snr(0) must be finite.")
+        if not ops.isfinite(self.derivative_log_snr(self.log_snr_min, training=False)):
+            raise ValueError("d/dt log_snr(1) must be finite.")
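Note for reviewers: `NoiseSchedule` is the extension point this PR adds for custom schedules. As a quick illustration of the abstract API above, here is a minimal sketch of a subclass; the linear log-SNR interpolation and its bounds are made up for this example and are not part of the PR.

```python
# A minimal sketch of a custom schedule, assuming only the NoiseSchedule API
# defined in this PR; the linear interpolation is illustrative.
from keras import ops

from bayesflow.experimental.diffusion_model import NoiseSchedule


class LinearLogSNRSchedule(NoiseSchedule):
    """Interpolates lambda(t) linearly from log_snr_max (t=0) to log_snr_min (t=1)."""

    def __init__(self, min_log_snr: float = -10.0, max_log_snr: float = 10.0):
        super().__init__(name="linear_log_snr_schedule", variance_type="preserving")
        self.log_snr_min = min_log_snr
        self.log_snr_max = max_log_snr

    def get_log_snr(self, t, training: bool):
        # lambda(t) = lambda_max + t * (lambda_min - lambda_max)
        return self.log_snr_max + t * (self.log_snr_min - self.log_snr_max)

    def get_t_from_log_snr(self, log_snr_t, training: bool):
        return (log_snr_t - self.log_snr_max) / (self.log_snr_min - self.log_snr_max)

    def derivative_log_snr(self, log_snr_t, training: bool):
        # d/dt log(1 + e^(-lambda(t))) = -sigmoid(-lambda) * dlambda/dt
        dlambda_dt = self.log_snr_min - self.log_snr_max
        return -ops.sigmoid(-log_snr_t) * dlambda_dt


LinearLogSNRSchedule().validate()  # passes the finiteness and ordering checks above
```

diff --git a/bayesflow/utils/__init__.py b/bayesflow/utils/__init__.py
index b12ff823b..776e42fcd 100644
--- a/bayesflow/utils/__init__.py
+++ b/bayesflow/utils/__init__.py
@@ -33,6 +33,7 @@
     find_summary_network,
     find_inference_network,
     find_distribution,
+    find_noise_schedule,
 )
 
 from .ecdf import simultaneous_ecdf_bands, ranks
@@ -46,10 +47,7 @@
 )
 
 from .hparam_utils import find_batch_size, find_memory_budget
-
-from .integrate import (
-    integrate,
-)
+from .integrate import integrate, integrate_stochastic
 
 from .io import (
     pickle_load,
diff --git a/bayesflow/utils/dispatch/find_noise_schedule.py b/bayesflow/utils/dispatch/find_noise_schedule.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py
index 5e3b407ec..3f2d7f5c0 100644
--- a/bayesflow/utils/integrate.py
+++ b/bayesflow/utils/integrate.py
@@ -4,10 +4,11 @@
 import keras
 import numpy as np
 
 from typing import Literal
 
 from bayesflow.types import Tensor
 from bayesflow.utils import filter_kwargs
+
 from . import logging
 
 ArrayLike = int | float | Tensor
@@ -293,3 +294,106 @@
         return integrate_scheduled(fn, state, steps, method, **kwargs)
     else:
         raise RuntimeError(f"Type or value of `steps` not understood (steps={steps})")
+
+
+def euler_maruyama_step(
+    drift_fn: Callable,
+    diffusion_fn: Callable,
+    state: dict[str, ArrayLike],
+    time: ArrayLike,
+    step_size: ArrayLike,
+    noise: dict[str, ArrayLike],
+) -> tuple[dict[str, ArrayLike], ArrayLike]:
+    """
+    Performs a single Euler-Maruyama step for stochastic differential equations.
+
+    Args:
+        drift_fn: Function computing the drift term f(t, **state).
+        diffusion_fn: Function computing the diffusion term g(t, **state).
+        state: Current state, mapping variable names to tensors.
+        time: Current time scalar tensor.
+        step_size: Time increment dt.
+        noise: Mapping of variable names to dW noise tensors.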
+
+    Returns:
+        new_state: Updated state after one Euler-Maruyama step.
+        new_time: time + step_size.
+    """
+    # Compute drift and diffusion
+    drift = drift_fn(time, **filter_kwargs(state, drift_fn))
+    diffusion = diffusion_fn(time, **filter_kwargs(state, diffusion_fn))
+
+    # Check noise keys
+    if set(diffusion.keys()) != set(noise.keys()):
+        raise ValueError("Keys of diffusion terms and noise do not match.")
+
+    new_state = {}
+    for key, d in drift.items():
+        base = state[key] + step_size * d
+        if key in diffusion:  # stochastic update
+            base = base + diffusion[key] * noise[key]
+        new_state[key] = base
+
+    return new_state, time + step_size
+
+
+def integrate_stochastic(
+    drift_fn: Callable,
+    diffusion_fn: Callable,
+    state: dict[str, ArrayLike],
+    start_time: ArrayLike,
+    stop_time: ArrayLike,
+    steps: int,
+    seed: keras.random.SeedGenerator,
+    method: str = "euler_maruyama",
+    **kwargs,
+) -> dict[str, ArrayLike]:
+    """
+    Integrates a stochastic differential equation from start_time to stop_time.
+
+    Args:
+        drift_fn: Function that computes the drift term.
+        diffusion_fn: Function that computes the diffusion term.
+        state: Dictionary containing the initial state.
+        start_time: Starting time for integration.
+        stop_time: Ending time for integration.
+        steps: Number of integration steps.
+        seed: Random seed for noise generation.
+        method: Integration method to use, e.g., 'euler_maruyama'.
+        **kwargs: Additional arguments to pass to the step function.
+
+    Returns:
+        The final state dictionary after integration.
+    """
+    if steps <= 0:
+        raise ValueError("Number of steps must be positive.")
+
+    # Select step function based on method
+    match method:
+        case "euler_maruyama":
+            step_fn = euler_maruyama_step
+        case other:
+            raise ValueError(f"Invalid integration method: {other!r}")
+
+    # Prepare step function with partial application
+    step_fn = partial(step_fn, drift_fn=drift_fn, diffusion_fn=diffusion_fn, **kwargs)
+
+    # Time step
+    step_size = (stop_time - start_time) / steps
+    sqrt_dt = keras.ops.sqrt(keras.ops.abs(step_size))
+
+    # Pre-generate noise history: shape = (steps, *state_shape)
+    noise_history = {}
+    for key, val in state.items():
+        noise_history[key] = (
+            keras.random.normal((steps, *keras.ops.shape(val)), dtype=keras.ops.dtype(val), seed=seed) * sqrt_dt
+        )
+
+    def body(_loop_var, _loop_state):
+        _current_state, _current_time = _loop_state
+        _noise_i = {k: noise_history[k][_loop_var] for k in _current_state.keys()}
+        new_state, new_time = step_fn(state=_current_state, time=_current_time, step_size=step_size, noise=_noise_i)
+        return new_state, new_time
+
+    final_state, final_time = keras.ops.fori_loop(0, steps, body, (state, start_time))
+    return final_state
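Note for reviewers: `integrate_stochastic` is also usable on its own. A short sketch with a mean-reverting toy SDE follows; the drift and diffusion functions are made up for illustration, and nothing beyond the signature added above is assumed.

```python
# Toy usage of integrate_stochastic; the OU-style drift/diffusion is illustrative.
import keras

from bayesflow.utils import integrate_stochastic


def drift(time, x):
    # mean-reverting drift: dx = -x dt
    return {"x": -x}


def diffusion(time, x):
    # constant diffusion term g(t) = 0.5
    return {"x": 0.5 * keras.ops.ones_like(x)}


state = {"x": keras.random.normal((64, 2))}
final_state = integrate_stochastic(
    drift_fn=drift,
    diffusion_fn=diffusion,
    state=state,
    start_time=0.0,
    stop_time=1.0,
    steps=100,
    seed=keras.random.SeedGenerator(42),
)
print(keras.ops.shape(final_state["x"]))  # (64, 2)
```

diff --git a/tests/test_networks/conftest.py b/tests/test_networks/conftest.py
index 01ea5ad70..678029d92 100644
--- a/tests/test_networks/conftest.py
+++ b/tests/test_networks/conftest.py
@@ -3,6 +3,78 @@
 from bayesflow.networks import MLP
 
 
+@pytest.fixture()
+def diffusion_model_edm_F():
+    from bayesflow.experimental import DiffusionModel
+
+    return DiffusionModel(
+        subnet=MLP([8, 8]),
+        integrate_kwargs={"method": "rk45", "steps": 250},
+        noise_schedule="edm",
+        prediction_type="F",
+    )
+
+
+@pytest.fixture()
+def diffusion_model_edm_velocity():
+    from bayesflow.experimental import DiffusionModel
+
+    return DiffusionModel(
+        subnet=MLP([8, 8]),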
integrate_kwargs={"method": "rk45", "steps": 250}, + noise_schedule="edm", + prediction_type="velocity", + ) + + +@pytest.fixture() +def diffusion_model_edm_noise(): + from bayesflow.experimental import DiffusionModel + + return DiffusionModel( + subnet=MLP([8, 8]), + integrate_kwargs={"method": "rk45", "steps": 250}, + noise_schedule="edm", + prediction_type="noise", + ) + + +@pytest.fixture() +def diffusion_model_cosine_F(): + from bayesflow.experimental import DiffusionModel + + return DiffusionModel( + subnet=MLP([8, 8]), + integrate_kwargs={"method": "rk45", "steps": 250}, + noise_schedule="cosine", + prediction_type="F", + ) + + +@pytest.fixture() +def diffusion_model_cosine_velocity(): + from bayesflow.experimental import DiffusionModel + + return DiffusionModel( + subnet=MLP([8, 8]), + integrate_kwargs={"method": "rk45", "steps": 250}, + noise_schedule="cosine", + prediction_type="velocity", + ) + + +@pytest.fixture() +def diffusion_model_cosine_noise(): + from bayesflow.experimental import DiffusionModel + + return DiffusionModel( + subnet=MLP([8, 8]), + integrate_kwargs={"method": "rk45", "steps": 250}, + noise_schedule="cosine", + prediction_type="noise", + ) + + @pytest.fixture() def flow_matching(): from bayesflow.networks import FlowMatching @@ -86,6 +158,12 @@ def typical_point_inference_network_subnet(): "flow_matching", "free_form_flow", "consistency_model", + pytest.param("diffusion_model_edm_F", marks=pytest.mark.diffusion_model), + pytest.param("diffusion_model_edm_noise", marks=[pytest.mark.slow, pytest.mark.diffusion_model]), + pytest.param("diffusion_model_cosine_velocity", marks=[pytest.mark.slow, pytest.mark.diffusion_model]), + pytest.param("diffusion_model_cosine_F", marks=[pytest.mark.slow, pytest.mark.diffusion_model]), + pytest.param("diffusion_model_cosine_noise", marks=[pytest.mark.slow, pytest.mark.diffusion_model]), + pytest.param("diffusion_model_cosine_velocity", marks=[pytest.mark.slow, pytest.mark.diffusion_model]), ], scope="function", ) @@ -107,7 +185,47 @@ def inference_network_subnet(request): @pytest.fixture( - params=["affine_coupling_flow", "spline_coupling_flow", "flow_matching", "free_form_flow", "consistency_model"], + params=[ + "affine_coupling_flow", + "spline_coupling_flow", + "flow_matching", + "free_form_flow", + "consistency_model", + pytest.param("diffusion_model_edm_F", marks=pytest.mark.diffusion_model), + pytest.param( + "diffusion_model_edm_noise", + marks=[ + pytest.mark.slow, + pytest.mark.diffusion_model, + pytest.mark.skip("noise predicition not testable without prior training for numerical reasons."), + ], + ), + pytest.param("diffusion_model_cosine_velocity", marks=[pytest.mark.slow, pytest.mark.diffusion_model]), + pytest.param( + "diffusion_model_cosine_F", + marks=[ + pytest.mark.slow, + pytest.mark.diffusion_model, + pytest.mark.skip("skip to reduce load on CI."), + ], + ), + pytest.param( + "diffusion_model_cosine_noise", + marks=[ + pytest.mark.slow, + pytest.mark.diffusion_model, + pytest.mark.skip("noise predicition not testable without prior training for numerical reasons."), + ], + ), + pytest.param( + "diffusion_model_cosine_velocity", + marks=[ + pytest.mark.slow, + pytest.mark.diffusion_model, + pytest.mark.skip("skip to reduce load on CI."), + ], + ), + ], scope="function", ) def generative_inference_network(request): diff --git a/tests/test_networks/test_diffusion_model/__init__.py b/tests/test_networks/test_diffusion_model/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git 
a/tests/test_networks/test_diffusion_model/conftest.py b/tests/test_networks/test_diffusion_model/conftest.py new file mode 100644 index 000000000..72946c71c --- /dev/null +++ b/tests/test_networks/test_diffusion_model/conftest.py @@ -0,0 +1,23 @@ +import pytest + + +@pytest.fixture() +def cosine_noise_schedule(): + from bayesflow.experimental.diffusion_model import CosineNoiseSchedule + + return CosineNoiseSchedule(min_log_snr=-12, max_log_snr=12, shift=0.1, weighting="likelihood_weighting") + + +@pytest.fixture() +def edm_noise_schedule(): + from bayesflow.experimental.diffusion_model import EDMNoiseSchedule + + return EDMNoiseSchedule(sigma_data=10.0, sigma_min=1e-5, sigma_max=85.0) + + +@pytest.fixture( + params=["cosine_noise_schedule", "edm_noise_schedule"], + scope="function", +) +def noise_schedule(request): + return request.getfixturevalue(request.param) diff --git a/tests/test_networks/test_diffusion_model/test_diffusion_model.py b/tests/test_networks/test_diffusion_model/test_diffusion_model.py new file mode 100644 index 000000000..a0538a663 --- /dev/null +++ b/tests/test_networks/test_diffusion_model/test_diffusion_model.py @@ -0,0 +1,25 @@ +def test_serialize_deserialize_noise_schedule(noise_schedule): + from bayesflow.utils.serialization import serialize, deserialize + + serialized = serialize(noise_schedule) + deserialized = deserialize(serialized) + reserialized = serialize(deserialized) + + assert serialized == reserialized + t = 0.251 + x = 0.5 + training = True + assert noise_schedule.get_log_snr(t, training=training) == deserialized.get_log_snr(t, training=training) + assert noise_schedule.get_t_from_log_snr(t, training=training) == deserialized.get_t_from_log_snr( + t, training=training + ) + assert noise_schedule.derivative_log_snr(t, training=False) == deserialized.derivative_log_snr(t, training=False) + assert noise_schedule.get_drift_diffusion(t, x, training=False) == deserialized.get_drift_diffusion( + t, x, training=False + ) + assert noise_schedule.get_alpha_sigma(t, training=training) == deserialized.get_alpha_sigma(t, training=training) + assert noise_schedule.get_weights_for_snr(t) == deserialized.get_weights_for_snr(t) + + +def test_validate_noise_schedule(noise_schedule): + noise_schedule.validate() diff --git a/tests/test_utils/test_dispatch.py b/tests/test_utils/test_dispatch.py index df25ea78e..e8e497bc8 100644 --- a/tests/test_utils/test_dispatch.py +++ b/tests/test_utils/test_dispatch.py @@ -2,6 +2,7 @@ import pytest from bayesflow.utils import find_inference_network, find_distribution, find_summary_network +from bayesflow.experimental.diffusion_model import find_noise_schedule # --- Tests for find_inference_network.py --- @@ -168,3 +169,72 @@ def test_find_summary_network_unknown_name(): def test_find_summary_network_invalid_type(): with pytest.raises(TypeError): find_summary_network(0.1234) + + +def test_find_noise_schedule_by_name(): + from bayesflow.experimental.diffusion_model import CosineNoiseSchedule, EDMNoiseSchedule + + schedule = find_noise_schedule("cosine") + assert isinstance(schedule, CosineNoiseSchedule) + + schedule = find_noise_schedule("edm") + assert isinstance(schedule, EDMNoiseSchedule) + + +def test_find_noise_schedule_unknown_name(): + with pytest.raises(ValueError): + find_noise_schedule("unknown_noise_schedule") + + +def test_pass_noise_schedule(): + from bayesflow.experimental.diffusion_model import NoiseSchedule + + class CustomNoiseSchedule(NoiseSchedule): + def __init__(self): + pass + + def get_log_snr(self, t, training): 
+ pass + + def get_t_from_log_snr(self, log_snr_t, training): + pass + + def derivative_log_snr(self, log_snr_t, training): + pass + + schedule = CustomNoiseSchedule() + assert schedule is find_noise_schedule(schedule) + + +def test_pass_noise_schedule_type(): + from bayesflow.experimental.diffusion_model import EDMNoiseSchedule + + schedule = find_noise_schedule(EDMNoiseSchedule, sigma_data=10.0) + assert isinstance(schedule, EDMNoiseSchedule) + assert schedule.sigma_data == 10.0 + + +def test_find_noise_schedule_by_dict(): + from bayesflow.experimental.diffusion_model import CosineNoiseSchedule, EDMNoiseSchedule + + schedule = find_noise_schedule({"name": "cosine"}) + assert isinstance(schedule, CosineNoiseSchedule) + + schedule = find_noise_schedule({"name": "edm", "sigma_data": 10}) + assert isinstance(schedule, EDMNoiseSchedule) + assert schedule.sigma_data == 10 + + +def test_find_noise_schedule_unknown_name_in_dict(): + with pytest.raises(ValueError): + find_noise_schedule({"name": "unknown_noise_schedule"}) + + +def test_find_noise_schedule_invalid_class(): + with pytest.raises(TypeError): + find_noise_schedule(int) + + +def test_find_noise_schedule_invalid_object(): + with pytest.raises(TypeError): + find_noise_schedule(1.0)
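Finally, a minimal end-to-end smoke test of the new network for anyone trying the branch locally. It relies only on interfaces visible in this diff (the constructor, the dict form dispatched via `find_noise_schedule`, and `compute_metrics`, which builds the network lazily); shapes and hyperparameters are illustrative.

```python
# Minimal smoke test for the new DiffusionModel; shapes are illustrative only.
import keras

from bayesflow.experimental import DiffusionModel

model = DiffusionModel(
    noise_schedule={"name": "edm", "sigma_data": 1.0},  # dict form, see find_noise_schedule
    prediction_type="F",
    integrate_kwargs={"method": "euler", "steps": 50},
)

x = keras.random.normal((8, 4))           # targets
conditions = keras.random.normal((8, 6))  # conditioning variables

# compute_metrics builds the network on first use and returns the training loss
metrics = model.compute_metrics(x, conditions=conditions, stage="training")
print(metrics["loss"])
```

Sampling and density evaluation go through the inherited `InferenceNetwork` interface (backed by `_inverse` and `_forward` above) and are exercised by the `generative_inference_network` fixtures in this PR.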