|
2 | 2 | from abc import ABC, abstractmethod |
3 | 3 | import keras |
4 | 4 | from keras import ops |
| 5 | +import warnings |
5 | 6 |
|
6 | 7 | from bayesflow.utils.serialization import serialize, deserialize, serializable |
7 | 8 | from bayesflow.types import Tensor, Shape |
@@ -389,7 +390,7 @@ def __init__( |
389 | 390 | **kwargs |
390 | 391 | Additional keyword arguments passed to the subnet and other components. |
391 | 392 | """ |
392 | | - super().__init__(base_distribution="normal", **kwargs) |
| 393 | + super().__init__(base_distribution=None, **kwargs) |
393 | 394 |
|
394 | 395 | if isinstance(noise_schedule, str): |
395 | 396 | if noise_schedule == "linear": |
@@ -419,6 +420,13 @@ def __init__( |
419 | 420 | self.integrate_kwargs = self.INTEGRATE_DEFAULT_CONFIG | (integrate_kwargs or {}) |
420 | 421 | self.seed_generator = keras.random.SeedGenerator() |
421 | 422 |
|
| 423 | + if subnet_kwargs: |
| 424 | + warnings.warn( |
| 425 | + "Using `subnet_kwargs` is deprecated." |
| 426 | + "Instead, instantiate the network yourself and pass the arguments directly.", |
| 427 | + DeprecationWarning, |
| 428 | + ) |
| 429 | + |
422 | 430 | subnet_kwargs = subnet_kwargs or {} |
423 | 431 | if subnet == "mlp": |
424 | 432 | subnet_kwargs = self.MLP_DEFAULT_CONFIG | subnet_kwargs |
@@ -643,7 +651,7 @@ def compute_metrics( |
643 | 651 |
|
644 | 652 | # sample training diffusion time as low discrepancy sequence to decrease variance |
645 | 653 | # t_i = \mod (u_0 + i/k, 1) |
646 | | - u0 = keras.random.uniform(shape=(1,), dtype=ops.dtype(x)) |
| 654 | + u0 = keras.random.uniform(shape=(1,), dtype=ops.dtype(x), seed=self.seed_generator) |
647 | 655 | i = ops.arange(0, keras.ops.shape(x)[0], dtype=ops.dtype(x)) # tensor of indices |
648 | 656 | t = (u0 + i / ops.cast(keras.ops.shape(x)[0], dtype=ops.dtype(x))) % 1 |
649 | 657 | # i = keras.random.randint((keras.ops.shape(x)[0],), minval=0, maxval=self._timesteps) |
|
0 commit comments