Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
81 commits
Select commit Hold shift + click to select a range
549a055
Add diffusion model implementation, EDM variant
vpratz Apr 13, 2025
630a823
adding more noise schedules
arrjon Apr 16, 2025
c1cb183
adding noise scheduler class
arrjon Apr 23, 2025
49c0cb7
adding noise scheduler class
arrjon Apr 23, 2025
5f11724
Merge branch 'main' into feat-diffusion-model
arrjon Apr 23, 2025
280b651
Merge branch 'dev' into feat-diffusion-model
vpratz Apr 24, 2025
e840046
fix backend
arrjon Apr 24, 2025
f2d7de4
fix backend
arrjon Apr 24, 2025
d5dc2ba
wip: adapt network to layer paradigm
vpratz Apr 24, 2025
739491a
improve schedules
arrjon Apr 24, 2025
efeff85
Merge branch 'feat-diffusion-model-adapt' into feat-diffusion-model
vpratz Apr 24, 2025
92131d7
add serialization, remove unnecessary tensor conversions
vpratz Apr 24, 2025
bd564b5
format inference network conftest.py
vpratz Apr 24, 2025
0f7b3f5
add dtypes and type casts in compute_metrics
vpratz Apr 24, 2025
2ce74f0
disable clip on x by default
vpratz Apr 24, 2025
01b33dc
fixes: use squared g, correct typo in _min_t
vpratz Apr 24, 2025
6031212
integration should be from 1 to 0
arrjon Apr 24, 2025
d82e2bf
add missing seed_generator param
vpratz Apr 24, 2025
d8d6246
Merge branch 'feat-diffusion-model' of github.com:bayesflow-org/bayes…
vpratz Apr 24, 2025
bdb27e8
correct integration times for forward direction
vpratz Apr 24, 2025
ca52fc0
flip integration times for correct direction of integration
vpratz Apr 24, 2025
cbd3568
swap mapping log_snr_min/max to t_min/max
vpratz Apr 24, 2025
9b520bc
fix mapping min/max snr to t_min/max
arrjon Apr 24, 2025
3757c9d
Merge remote-tracking branch 'origin/feat-diffusion-model' into feat-…
arrjon Apr 24, 2025
e32e8ad
fix linear schedule
arrjon Apr 24, 2025
3455ce1
rename prediction type
arrjon Apr 24, 2025
95ca126
fix: remove unnecessary covert_to_tensor call
vpratz Apr 24, 2025
495ed29
fix validate noise schedule for training
arrjon Apr 24, 2025
59a349b
minor change in diffusion weightings
arrjon Apr 24, 2025
612b17b
add euler_maruyama sampler
arrjon Apr 24, 2025
de532c7
abs step size
arrjon Apr 24, 2025
9ed482d
stochastic sampler
arrjon Apr 24, 2025
2fd5a90
Merge pull request #440 from bayesflow-org/feat-stochastic-sampler
arrjon Apr 24, 2025
548f51b
stochastic sampler fix
arrjon Apr 25, 2025
194a503
fix scale base dist
arrjon Apr 25, 2025
196683c
EDM training bounds
arrjon Apr 25, 2025
5b52499
minor changes
arrjon Apr 25, 2025
eb96620
fix base distribution
arrjon Apr 25, 2025
668f6fc
seed in stochastic sampler
arrjon Apr 25, 2025
1a970c2
seed in stochastic sampler
arrjon Apr 25, 2025
ebafc5e
seed in stochastic sampler
arrjon Apr 25, 2025
9941fa3
seed in stochastic sampler
arrjon Apr 25, 2025
afaebef
seed in stochastic sampler
arrjon Apr 25, 2025
c1558c5
seed in stochastic sampler
arrjon Apr 25, 2025
1efd88f
fix is_symbolic_tensor
LarsKue Apr 25, 2025
7456cdb
[skip ci] skip step_fn for tracing (dangerous, subject to removal)
LarsKue Apr 25, 2025
a722729
seed in stochastic sampler
arrjon Apr 26, 2025
ee0c87b
seed in stochastic sampler
arrjon Apr 26, 2025
f2cbde6
fix loss
arrjon Apr 28, 2025
7b7b15a
fix loss
arrjon Apr 28, 2025
1811038
improve schedules
arrjon Apr 28, 2025
9d13264
improve schedules
arrjon Apr 28, 2025
4e0b7f8
improve edm
arrjon Apr 28, 2025
a028e8a
temporary: add notebook to compare implementations
vpratz Apr 29, 2025
1f15b7d
Merge remote-tracking branch 'upstream/dev' into feat-diffusion-model
vpratz Apr 29, 2025
6794342
add loss types
arrjon Apr 29, 2025
7c527a5
add loss types
arrjon Apr 29, 2025
5ca609f
scale snr
arrjon Apr 29, 2025
79be9ab
fix stochastic sampler
arrjon Apr 29, 2025
f235671
remove notebook used for testing the implementations
vpratz Apr 30, 2025
e380f5e
cleanup: remove linear schedule, minor fixes
vpratz Apr 30, 2025
8402a3f
add diffusion models to inference network tests
vpratz Apr 30, 2025
78814ac
[no ci] remove contradictory comment in EDM schedule
vpratz May 2, 2025
6572f06
Merge branch 'dev' into feat-diffusion-model
arrjon May 7, 2025
a736578
improved docs, comments, typing
arrjon May 7, 2025
b5d6f0f
fix noise schedule dispatch
arrjon May 7, 2025
0781032
fix serializable
arrjon May 7, 2025
0e5a48c
fix Literal
arrjon May 7, 2025
d43ea98
Move all files into a separate directory, fix serialization
vpratz May 8, 2025
0b5a800
docstring formatting
vpratz May 8, 2025
e8d34d7
adapt test: use subnet arg, increase number of integration steps
vpratz May 8, 2025
7b55a37
fix stochastic sampler
arrjon May 8, 2025
2566eb4
fix TypeError
arrjon May 8, 2025
1cb5547
Merge remote-tracking branch 'upstream/dev' into feat-diffusion-model
vpratz May 8, 2025
3ee8582
Merge branch 'refs/heads/dev' into feat-diffusion-model
arrjon May 9, 2025
0f96265
base class noise schedule
arrjon May 9, 2025
f4f0d11
improve dispatch
arrjon May 9, 2025
b9a7e36
restructure noise schedules, test find_noise_schedule
vpratz May 11, 2025
a387ed9
add serialization tests, minor fix to cosine schedule config
vpratz May 11, 2025
7253fa8
[no ci] skip to variants in tests, as the tests take very long
vpratz May 11, 2025
051b515
Merge remote-tracking branch 'upstream/dev' into feat-diffusion-model
vpratz May 11, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion bayesflow/experimental/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@

from .cif import CIF
from .continuous_time_consistency_model import ContinuousTimeConsistencyModel
from .diffusion_model import DiffusionModel
from .free_form_flow import FreeFormFlow

from ..utils._docs import _add_imports_to_all

_add_imports_to_all(include_modules=[])
_add_imports_to_all(include_modules=["diffusion_model"])
9 changes: 9 additions & 0 deletions bayesflow/experimental/diffusion_model/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from .diffusion_model import DiffusionModel
from .noise_schedule import NoiseSchedule
from .cosine_noise_schedule import CosineNoiseSchedule
from .edm_noise_schedule import EDMNoiseSchedule
from .dispatch import find_noise_schedule

from ...utils._docs import _add_imports_to_all

_add_imports_to_all(include_modules=[])
85 changes: 85 additions & 0 deletions bayesflow/experimental/diffusion_model/cosine_noise_schedule.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import math
from typing import Union, Literal

from keras import ops

from bayesflow.types import Tensor
from bayesflow.utils.serialization import deserialize, serializable

from .noise_schedule import NoiseSchedule


# disable module check, use potential module after moving from experimental
@serializable("bayesflow.networks", disable_module_check=True)
class CosineNoiseSchedule(NoiseSchedule):
    """Cosine noise schedule for diffusion models. This schedule is based on the cosine schedule from [1].

    The log signal-to-noise ratio is lambda(t) = -2 * log(tan(pi * t / 2)) + 2 * shift.
    Diffusion time t in [0, 1] is affinely mapped onto a truncated interval
    [t_min, t_max] so that lambda stays within [min_log_snr, max_log_snr].

    [1] Diffusion Models Beat GANs on Image Synthesis: Dhariwal and Nichol (2022)
    """

    def __init__(
        self,
        min_log_snr: float = -15,
        max_log_snr: float = 15,
        shift: float = 0.0,
        weighting: Literal["sigmoid", "likelihood_weighting"] = "sigmoid",
    ):
        """
        Initialize the cosine noise schedule.

        Parameters
        ----------
        min_log_snr : float, optional
            The minimum log signal-to-noise ratio (lambda). Default is -15.
        max_log_snr : float, optional
            The maximum log signal-to-noise ratio (lambda). Default is 15.
        shift : float, optional
            Shift the log signal-to-noise ratio (lambda) by this amount. Default is 0.0.
            For images, use shift = log(base_resolution / d), where d is the used resolution of the image.
        weighting : Literal["sigmoid", "likelihood_weighting"], optional
            The type of weighting function to use for the noise schedule. Default is "sigmoid".
        """
        super().__init__(name="cosine_noise_schedule", variance_type="preserving", weighting=weighting)
        self._shift = shift
        self._weighting = weighting
        self.log_snr_min = min_log_snr
        self.log_snr_max = max_log_snr

        # Raw schedule times corresponding to the SNR bounds: high SNR <-> small t,
        # low SNR <-> large t, hence the max/min swap below.
        self._t_min = self.get_t_from_log_snr(log_snr_t=self.log_snr_max, training=True)
        self._t_max = self.get_t_from_log_snr(log_snr_t=self.log_snr_min, training=True)

    def _truncated_t(self, t: Tensor) -> Tensor:
        """Map normalized diffusion time t in [0, 1] onto the truncated interval [t_min, t_max]."""
        return self._t_min + (self._t_max - self._t_min) * t

    def get_log_snr(self, t: Union[float, Tensor], training: bool) -> Tensor:
        """Get the log signal-to-noise ratio (lambda) for a given diffusion time."""
        t_trunc = self._truncated_t(t)
        return -2 * ops.log(ops.tan(math.pi * t_trunc * 0.5)) + 2 * self._shift

    def get_t_from_log_snr(self, log_snr_t: Union[Tensor, float], training: bool) -> Tensor:
        """Get the (truncated) diffusion time from the log signal-to-noise ratio (lambda).

        Note that this inverts the raw cosine schedule, so the returned time lies in
        the truncated interval [t_min, t_max], not in the normalized interval [0, 1].
        """
        # SNR = -2 * log(tan(pi*t/2)) + 2 * shift => t = 2/pi * arctan(exp((2*shift - snr)/2))
        return 2 / math.pi * ops.arctan(ops.exp((2 * self._shift - log_snr_t) * 0.5))

    def derivative_log_snr(self, log_snr_t: Tensor, training: bool) -> Tensor:
        """Compute d/dt log(1 + e^(-snr(t))), which is used for the reverse SDE.

        The derivative is taken with respect to the normalized time t in [0, 1],
        hence the chain-rule factor (t_max - t_min) from the truncation mapping.
        """
        # get_t_from_log_snr inverts the raw schedule, so the result already lies in
        # the truncated interval [t_min, t_max]; applying _truncated_t again would
        # evaluate the derivative at the wrong time point.
        t_trunc = self.get_t_from_log_snr(log_snr_t=log_snr_t, training=training)

        # d/dx [-2 * log(tan(pi*x/2))] = -2*pi / sin(pi*x)
        dsnr_dx = -(2 * math.pi) / ops.sin(math.pi * t_trunc)

        # Chain rule for the truncation mapping x = t_min + (t_max - t_min) * t
        dsnr_dt = dsnr_dx * (self._t_max - self._t_min)

        # Using the chain rule on f(t) = log(1 + e^(-snr(t))):
        # f'(t) = - (e^{-snr(t)} / (1 + e^{-snr(t)})) * dsnr_dt
        factor = ops.exp(-log_snr_t) / (1 + ops.exp(-log_snr_t))
        return -factor * dsnr_dt

    def get_config(self):
        """Return the constructor arguments for serialization."""
        return dict(
            min_log_snr=self.log_snr_min, max_log_snr=self.log_snr_max, shift=self._shift, weighting=self._weighting
        )

    @classmethod
    def from_config(cls, config, custom_objects=None):
        """Reconstruct the schedule from a serialized config."""
        return cls(**deserialize(config, custom_objects=custom_objects))
Loading