|
13 | 13 | expand_right_as, |
14 | 14 | find_network, |
15 | 15 | jacobian_trace, |
16 | | - keras_kwargs, |
17 | 16 | serialize_value_or_type, |
18 | 17 | deserialize_value_or_type, |
19 | 18 | weighted_mean, |
20 | 19 | integrate, |
21 | 20 | ) |
22 | 21 |
|
23 | 22 |
|
24 | | -match keras.backend.backend(): |
25 | | - case "jax": |
26 | | - from jax.scipy.special import erf, erfinv |
27 | | - |
28 | | - def cdf_gaussian(x, loc, scale): |
29 | | - return 0.5 * (1 + erf((x - loc) / (scale * math.sqrt(2.0)))) |
30 | | - |
31 | | - def icdf_gaussian(x, loc, scale): |
32 | | - return loc + scale * erfinv(2 * x - 1) * math.sqrt(2) |
33 | | - case "numpy": |
34 | | - from scipy.special import erf, erfinv |
35 | | - |
36 | | - def cdf_gaussian(x, loc, scale): |
37 | | - return 0.5 * (1 + erf((x - loc) / (scale * math.sqrt(2.0)))) |
38 | | - |
39 | | - def icdf_gaussian(x, loc, scale): |
40 | | - return loc + scale * erfinv(2 * x - 1) * math.sqrt(2.0) |
41 | | - case "tensorflow": |
42 | | - from tensorflow.math import erf, erfinv |
43 | | - |
44 | | - def cdf_gaussian(x, loc, scale): |
45 | | - return 0.5 * (1 + erf((x - loc) / (scale * math.sqrt(2.0)))) |
46 | | - |
47 | | - def icdf_gaussian(x, loc, scale): |
48 | | - return loc + scale * erfinv(2 * x - 1) * math.sqrt(2.0) |
49 | | - case "torch": |
50 | | - from torch import erf, erfinv |
51 | | - |
52 | | - def cdf_gaussian(x, loc, scale): |
53 | | - return 0.5 * (1 + erf((x - loc) / (scale * math.sqrt(2.0)))) |
54 | | - |
55 | | - def icdf_gaussian(x, loc, scale): |
56 | | - return loc + scale * erfinv(2 * x - 1) * math.sqrt(2.0) |
57 | | - case other: |
58 | | - raise ValueError(f"Backend '{other}' is not supported.") |
59 | | - |
60 | | - |
61 | 23 | class NoiseSchedule(ABC): |
62 | 24 | """Noise schedule for diffusion models. We follow the notation from [1]. |
63 | 25 |
|
@@ -401,8 +363,7 @@ def __init__( |
401 | 363 | **kwargs |
402 | 364 | Additional keyword arguments passed to the subnet and other components. |
403 | 365 | """ |
404 | | - |
405 | | - super().__init__(base_distribution=None, **keras_kwargs(kwargs)) |
| 366 | + super().__init__(base_distribution="normal", **kwargs) |
406 | 367 |
|
407 | 368 | if isinstance(noise_schedule, str): |
408 | 369 | if noise_schedule == "linear": |
|
0 commit comments