Skip to content

Commit f2d7de4

Browse files
committed
fix backend
1 parent e840046 commit f2d7de4

File tree

1 file changed

+1
-40
lines changed

1 file changed

+1
-40
lines changed

bayesflow/experimental/diffusion_model.py

Lines changed: 1 addition & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -13,51 +13,13 @@
1313
expand_right_as,
1414
find_network,
1515
jacobian_trace,
16-
keras_kwargs,
1716
serialize_value_or_type,
1817
deserialize_value_or_type,
1918
weighted_mean,
2019
integrate,
2120
)
2221

2322

24-
match keras.backend.backend():
25-
case "jax":
26-
from jax.scipy.special import erf, erfinv
27-
28-
def cdf_gaussian(x, loc, scale):
29-
return 0.5 * (1 + erf((x - loc) / (scale * math.sqrt(2.0))))
30-
31-
def icdf_gaussian(x, loc, scale):
32-
return loc + scale * erfinv(2 * x - 1) * math.sqrt(2)
33-
case "numpy":
34-
from scipy.special import erf, erfinv
35-
36-
def cdf_gaussian(x, loc, scale):
37-
return 0.5 * (1 + erf((x - loc) / (scale * math.sqrt(2.0))))
38-
39-
def icdf_gaussian(x, loc, scale):
40-
return loc + scale * erfinv(2 * x - 1) * math.sqrt(2.0)
41-
case "tensorflow":
42-
from tensorflow.math import erf, erfinv
43-
44-
def cdf_gaussian(x, loc, scale):
45-
return 0.5 * (1 + erf((x - loc) / (scale * math.sqrt(2.0))))
46-
47-
def icdf_gaussian(x, loc, scale):
48-
return loc + scale * erfinv(2 * x - 1) * math.sqrt(2.0)
49-
case "torch":
50-
from torch import erf, erfinv
51-
52-
def cdf_gaussian(x, loc, scale):
53-
return 0.5 * (1 + erf((x - loc) / (scale * math.sqrt(2.0))))
54-
55-
def icdf_gaussian(x, loc, scale):
56-
return loc + scale * erfinv(2 * x - 1) * math.sqrt(2.0)
57-
case other:
58-
raise ValueError(f"Backend '{other}' is not supported.")
59-
60-
6123
class NoiseSchedule(ABC):
6224
"""Noise schedule for diffusion models. We follow the notation from [1].
6325
@@ -401,8 +363,7 @@ def __init__(
401363
**kwargs
402364
Additional keyword arguments passed to the subnet and other components.
403365
"""
404-
405-
super().__init__(base_distribution=None, **keras_kwargs(kwargs))
366+
super().__init__(base_distribution="normal", **kwargs)
406367

407368
if isinstance(noise_schedule, str):
408369
if noise_schedule == "linear":

0 commit comments

Comments
 (0)