Skip to content

Commit 161593c

Browse files
committed
not working: first steps to achieve serialization with the new functions
1 parent 662614e commit 161593c

File tree

6 files changed

+40
-22
lines changed

6 files changed

+40
-22
lines changed

bayesflow/links/ordered_quantiles.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import keras
22

33
from bayesflow.utils import layer_kwargs, logging
4-
from bayesflow.utils.serialization import serializable
4+
from bayesflow.utils.serialization import serializable, serialize, deserialize
55

66
from collections.abc import Sequence
77

@@ -21,7 +21,11 @@ def get_config(self):
2121
config = {
2222
"q": self.q,
2323
}
24-
return base_config | config
24+
return base_config | serialize(config)
25+
26+
@classmethod
27+
def from_config(cls, config):
28+
return cls(**deserialize(config))
2529

2630
def build(self, input_shape):
2731
if self.axis is None and 1 < len(input_shape) <= 3:

bayesflow/networks/point_inference_network.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,14 @@ def __init__(
2222
super().__init__(**model_kwargs(kwargs))
2323

2424
self.scores = scores
25+
2526
self.subnet = find_network(subnet, **kwargs.get("subnet_kwargs", {}))
2627

27-
self._kwargs = kwargs
28+
self.config = {
29+
"subnet": serialize(subnet),
30+
"scores": serialize(scores),
31+
**kwargs,
32+
}
2833

2934
def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None:
3035
"""Builds all network components based on shapes of conditions and targets.
@@ -103,18 +108,14 @@ def build_from_config(self, config):
103108

104109
def get_config(self):
105110
base_config = super().get_config()
106-
config = {
107-
"scores": self.scores,
108-
"subnet": self.subnet,
109-
**self._kwargs,
110-
}
111111

112-
return base_config | serialize(config)
112+
return base_config | self.config
113113

114114
@classmethod
115115
def from_config(cls, config):
116116
config = config.copy()
117-
return cls(**deserialize(config))
117+
config = deserialize(config)
118+
return cls(**config)
118119

119120
def call(
120121
self,

bayesflow/scores/mean_score.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from keras.saving import register_keras_serializable as serializable
1+
from bayesflow.utils.serialization import serializable
22

33
from .normed_difference_score import NormedDifferenceScore
44

@@ -12,4 +12,6 @@ class MeanScore(NormedDifferenceScore):
1212

1313
def __init__(self, **kwargs):
1414
super().__init__(k=2, **kwargs)
15-
self.config = {}
15+
16+
def get_config(self):
17+
return {}

bayesflow/scores/median_score.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from keras.saving import register_keras_serializable as serializable
1+
from bayesflow.utils.serialization import serializable
22

33
from .normed_difference_score import NormedDifferenceScore
44

@@ -12,4 +12,6 @@ class MedianScore(NormedDifferenceScore):
1212

1313
def __init__(self, **kwargs):
1414
super().__init__(k=1, **kwargs)
15-
self.config = {}
15+
16+
def get_config(self):
17+
return {}

bayesflow/scores/multivariate_normal_score.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import math
22

33
import keras
4-
from keras.saving import register_keras_serializable as serializable
54

65
from bayesflow.types import Shape, Tensor
76
from bayesflow.links import PositiveDefinite
7+
from bayesflow.utils.serialization import deserialize, serializable, serialize
88

99
from .parametric_distribution_score import ParametricDistributionScore
1010

@@ -32,11 +32,17 @@ def __init__(self, dim: int = None, links: dict = None, **kwargs):
3232
self.dim = dim
3333
self.links = links or {"covariance": PositiveDefinite()}
3434

35-
self.config = {"dim": dim}
36-
3735
def get_config(self):
3836
base_config = super().get_config()
39-
return base_config | self.config
37+
config = {
38+
"dim": self.dim,
39+
"links": self.links,
40+
}
41+
return base_config | serialize(config)
42+
43+
@classmethod
44+
def from_config(cls, config):
45+
return cls(**deserialize(config))
4046

4147
def get_head_shapes_from_target_shape(self, target_shape: Shape) -> dict[str, Shape]:
4248
self.dim = target_shape[-1]

bayesflow/scores/normed_difference_score.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import keras
2-
from keras.saving import register_keras_serializable as serializable
2+
from bayesflow.utils.serialization import deserialize, serializable
33

44
from bayesflow.types import Shape, Tensor
55
from bayesflow.utils import weighted_mean
@@ -20,8 +20,6 @@ def __init__(self, k: int, **kwargs):
2020
#: Exponent to absolute difference
2121
self.k = k
2222

23-
self.config = {"k": k}
24-
2523
def get_head_shapes_from_target_shape(self, target_shape: Shape):
2624
# keras.saving.load_model sometimes passes target_shape as a list, so we force a conversion
2725
target_shape = tuple(target_shape)
@@ -61,4 +59,9 @@ def score(self, estimates: dict[str, Tensor], targets: Tensor, weights: Tensor =
6159

6260
def get_config(self):
6361
base_config = super().get_config()
64-
return base_config | self.config
62+
config = dict(k=self.k)
63+
return base_config | config
64+
65+
@classmethod
66+
def from_config(cls, config):
67+
return cls(**deserialize(config))

0 commit comments

Comments
 (0)