We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 18f3a6e commit 7955cbdCopy full SHA for 7955cbd
bayesflow/networks/inference_network.py
@@ -1,15 +1,15 @@
1
import keras
2
3
from bayesflow.types import Shape, Tensor
4
-from bayesflow.utils import find_distribution
+from bayesflow.utils import find_distribution, keras_kwargs
5
from bayesflow.utils.decorators import allow_batch_size
6
7
8
class InferenceNetwork(keras.Layer):
9
MLP_DEFAULT_CONFIG = {}
10
11
def __init__(self, base_distribution: str = "normal", **kwargs):
12
- super().__init__(**kwargs)
+ super().__init__(**keras_kwargs(kwargs))
13
self.base_distribution = find_distribution(base_distribution)
14
15
def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None:
0 commit comments