Skip to content

Commit b206e5a

Browse files
committed
Revert point inference net
1 parent d8ae2d1 commit b206e5a

File tree

1 file changed

+20
-13
lines changed

1 file changed

+20
-13
lines changed

bayesflow/networks/point_inference_network.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
11
import keras
2+
from keras.saving import (
3+
deserialize_keras_object as deserialize,
4+
serialize_keras_object as serialize,
5+
register_keras_serializable as serializable,
6+
)
27

3-
8+
from bayesflow.utils import model_kwargs, find_network, serialize_value_or_type, deserialize_value_or_type
49
from bayesflow.types import Shape, Tensor
510
from bayesflow.scores import ScoringRule, ParametricDistributionScore
611
from bayesflow.utils.decorators import allow_batch_size
7-
from bayesflow.utils import model_kwargs, find_network
8-
from bayesflow.utils.serialization import serialize, deserialize, serializable
912

1013

11-
@serializable
14+
@serializable(package="networks.point_inference_network")
1215
class PointInferenceNetwork(keras.Model):
1316
"""Implements point estimation for user specified scoring rules by a shared feed forward architecture
1417
with separate heads for each scoring rule.
@@ -23,8 +26,15 @@ def __init__(
2326
super().__init__(**model_kwargs(kwargs))
2427

2528
self.scores = scores
29+
2630
self.subnet = find_network(subnet, **kwargs.get("subnet_kwargs", {}))
2731

32+
self.config = {
33+
**kwargs,
34+
}
35+
self.config = serialize_value_or_type(self.config, "subnet", subnet)
36+
self.config["scores"] = serialize(self.scores)
37+
2838
def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None:
2939
"""Builds all network components based on shapes of conditions and targets.
3040
@@ -102,18 +112,15 @@ def build_from_config(self, config):
102112

103113
def get_config(self):
104114
base_config = super().get_config()
105-
base_config = model_kwargs(base_config)
106-
107-
config = {
108-
"scores": self.scores,
109-
"subnet": self.subnet,
110-
}
111115

112-
return base_config | serialize(config)
116+
return base_config | self.config
113117

114118
@classmethod
115-
def from_config(cls, config, custom_objects=None):
116-
return cls(**deserialize(config, custom_objects=custom_objects))
119+
def from_config(cls, config):
120+
config = config.copy()
121+
config["scores"] = deserialize(config["scores"])
122+
config = deserialize_value_or_type(config, "subnet")
123+
return cls(**config)
117124

118125
def call(
119126
self,

0 commit comments

Comments
 (0)