Skip to content

Commit d82e2bf

Browse files
committed
add missing seed_generator param
1 parent 01b33dc commit d82e2bf

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

bayesflow/experimental/diffusion_model.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from abc import ABC, abstractmethod
33
import keras
44
from keras import ops
5+
import warnings
56

67
from bayesflow.utils.serialization import serialize, deserialize, serializable
78
from bayesflow.types import Tensor, Shape
@@ -389,7 +390,7 @@ def __init__(
389390
**kwargs
390391
Additional keyword arguments passed to the subnet and other components.
391392
"""
392-
super().__init__(base_distribution="normal", **kwargs)
393+
super().__init__(base_distribution=None, **kwargs)
393394

394395
if isinstance(noise_schedule, str):
395396
if noise_schedule == "linear":
@@ -419,6 +420,13 @@ def __init__(
419420
self.integrate_kwargs = self.INTEGRATE_DEFAULT_CONFIG | (integrate_kwargs or {})
420421
self.seed_generator = keras.random.SeedGenerator()
421422

423+
if subnet_kwargs:
424+
warnings.warn(
425+
"Using `subnet_kwargs` is deprecated."
426+
"Instead, instantiate the network yourself and pass the arguments directly.",
427+
DeprecationWarning,
428+
)
429+
422430
subnet_kwargs = subnet_kwargs or {}
423431
if subnet == "mlp":
424432
subnet_kwargs = self.MLP_DEFAULT_CONFIG | subnet_kwargs
@@ -643,7 +651,7 @@ def compute_metrics(
643651

644652
# sample training diffusion time as low discrepancy sequence to decrease variance
645653
# t_i = \mod (u_0 + i/k, 1)
646-
u0 = keras.random.uniform(shape=(1,), dtype=ops.dtype(x))
654+
u0 = keras.random.uniform(shape=(1,), dtype=ops.dtype(x), seed=self.seed_generator)
647655
i = ops.arange(0, keras.ops.shape(x)[0], dtype=ops.dtype(x)) # tensor of indices
648656
t = (u0 + i / ops.cast(keras.ops.shape(x)[0], dtype=ops.dtype(x))) % 1
649657
# i = keras.random.randint((keras.ops.shape(x)[0],), minval=0, maxval=self._timesteps)

0 commit comments

Comments
 (0)