|
7 | 7 |
|
8 | 8 | from bayesflow.utils.serialization import serialize, deserialize, serializable |
9 | 9 | from bayesflow.types import Tensor, Shape |
10 | | -import bayesflow as bf |
11 | 10 | from bayesflow.networks import InferenceNetwork |
12 | 11 | import math |
13 | 12 |
|
@@ -334,7 +333,7 @@ def derivative_log_snr(self, log_snr_t: Tensor, training: bool) -> Tensor: |
334 | 333 |
|
335 | 334 | def get_weights_for_snr(self, log_snr_t: Tensor) -> Tensor: |
336 | 335 | """Get weights for the signal-to-noise ratio (snr) for a given log signal-to-noise ratio (lambda).""" |
337 | | - return (ops.exp(-log_snr_t) + ops.square(self.sigma_data)) / ops.square(self.sigma_data) |
| 336 | + return ops.exp(-log_snr_t) + ops.square(self.sigma_data) # / ops.square(self.sigma_data) |
338 | 337 |
|
339 | 338 | def get_config(self): |
340 | 339 | return dict(sigma_data=self.sigma_data, sigma_min=self.sigma_min, sigma_max=self.sigma_max) |
@@ -403,7 +402,7 @@ def __init__( |
403 | 402 | **kwargs |
404 | 403 | Additional keyword arguments passed to the subnet and other components. |
405 | 404 | """ |
406 | | - super().__init__(base_distribution=None, **kwargs) |
| 405 | + super().__init__(base_distribution="normal", **kwargs) |
407 | 406 |
|
408 | 407 | if isinstance(noise_schedule, str): |
409 | 408 | if noise_schedule == "linear": |
@@ -433,7 +432,6 @@ def __init__( |
433 | 432 | self._clip_max = 5.0 |
434 | 433 |
|
435 | 434 | # latent distribution (not configurable) |
436 | | - self.base_distribution = bf.distributions.DiagonalNormal() |
437 | 435 | self.integrate_kwargs = self.INTEGRATE_DEFAULT_CONFIG | (integrate_kwargs or {}) |
438 | 436 | self.seed_generator = keras.random.SeedGenerator() |
439 | 437 |
|
|
0 commit comments