Skip to content

Commit ebafc5e

Browse files
committed
seed in stochastic sampler
1 parent 1a970c2 commit ebafc5e

File tree

2 files changed

+6
-8
lines changed

2 files changed

+6
-8
lines changed

bayesflow/experimental/diffusion_model.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
from bayesflow.utils.serialization import serialize, deserialize, serializable
99
from bayesflow.types import Tensor, Shape
10-
import bayesflow as bf
1110
from bayesflow.networks import InferenceNetwork
1211
import math
1312

@@ -334,7 +333,7 @@ def derivative_log_snr(self, log_snr_t: Tensor, training: bool) -> Tensor:
334333

335334
def get_weights_for_snr(self, log_snr_t: Tensor) -> Tensor:
336335
"""Get weights for the signal-to-noise ratio (snr) for a given log signal-to-noise ratio (lambda)."""
337-
return (ops.exp(-log_snr_t) + ops.square(self.sigma_data)) / ops.square(self.sigma_data)
336+
return ops.exp(-log_snr_t) + ops.square(self.sigma_data) # / ops.square(self.sigma_data)
338337

339338
def get_config(self):
340339
return dict(sigma_data=self.sigma_data, sigma_min=self.sigma_min, sigma_max=self.sigma_max)
@@ -403,7 +402,7 @@ def __init__(
403402
**kwargs
404403
Additional keyword arguments passed to the subnet and other components.
405404
"""
406-
super().__init__(base_distribution=None, **kwargs)
405+
super().__init__(base_distribution="normal", **kwargs)
407406

408407
if isinstance(noise_schedule, str):
409408
if noise_schedule == "linear":
@@ -433,7 +432,6 @@ def __init__(
433432
self._clip_max = 5.0
434433

435434
# latent distribution (not configurable)
436-
self.base_distribution = bf.distributions.DiagonalNormal()
437435
self.integrate_kwargs = self.INTEGRATE_DEFAULT_CONFIG | (integrate_kwargs or {})
438436
self.seed_generator = keras.random.SeedGenerator()
439437

bayesflow/utils/integrate.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -423,18 +423,18 @@ def integrate_stochastic(
423423
time = start_time
424424

425425
def body(_loop_var, _loop_state):
426-
_state, _time = _loop_state
426+
_state, _time, _seed = _loop_state
427427

428428
# Generate noise for this step
429429
_noise = {}
430430
for key in _state.keys():
431431
shape = keras.ops.shape(_state[key])
432-
_noise[key] = keras.random.normal(shape, seed=seed) * keras.ops.sqrt(keras.ops.abs(step_size))
432+
_noise[key] = keras.random.normal(shape, seed=_seed) * keras.ops.sqrt(keras.ops.abs(step_size))
433433

434434
# Perform integration step
435435
_state, _time, _ = step_fn(_state, _time, step_size, noise=_noise)
436436

437-
return _state, _time
437+
return _state, _time, _seed
438438

439-
state, time = keras.ops.fori_loop(0, steps, body, (state, time))
439+
state, time = keras.ops.fori_loop(0, steps, body, (state, time, seed))
440440
return state

0 commit comments

Comments
 (0)