Skip to content

Commit 7955cbd

Browse files
committed
[no-ci] fix: filter kwargs of InferenceNetwork
1 parent 18f3a6e commit 7955cbd

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

bayesflow/networks/inference_network.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
import keras
22

33
from bayesflow.types import Shape, Tensor
4-
from bayesflow.utils import find_distribution
4+
from bayesflow.utils import find_distribution, keras_kwargs
55
from bayesflow.utils.decorators import allow_batch_size
66

77

88
class InferenceNetwork(keras.Layer):
99
MLP_DEFAULT_CONFIG = {}
1010

1111
def __init__(self, base_distribution: str = "normal", **kwargs):
12-
super().__init__(**kwargs)
12+
super().__init__(**keras_kwargs(kwargs))
1313
self.base_distribution = find_distribution(base_distribution)
1414

1515
def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None:

0 commit comments

Comments
 (0)