11import keras
2+ from keras .saving import (
3+ deserialize_keras_object as deserialize ,
4+ serialize_keras_object as serialize ,
5+ register_keras_serializable as serializable ,
6+ )
27
3-
8+ from bayesflow . utils import model_kwargs , find_network , serialize_value_or_type , deserialize_value_or_type
49from bayesflow .types import Shape , Tensor
510from bayesflow .scores import ScoringRule , ParametricDistributionScore
611from bayesflow .utils .decorators import allow_batch_size
7- from bayesflow .utils import model_kwargs , find_network
8- from bayesflow .utils .serialization import serialize , deserialize , serializable
912
1013
11- @serializable
14+ @serializable ( package = "networks.point_inference_network" )
1215class PointInferenceNetwork (keras .Model ):
1316 """Implements point estimation for user specified scoring rules by a shared feed forward architecture
1417 with separate heads for each scoring rule.
@@ -23,8 +26,15 @@ def __init__(
2326 super ().__init__ (** model_kwargs (kwargs ))
2427
2528 self .scores = scores
29+
2630 self .subnet = find_network (subnet , ** kwargs .get ("subnet_kwargs" , {}))
2731
32+ self .config = {
33+ ** kwargs ,
34+ }
35+ self .config = serialize_value_or_type (self .config , "subnet" , subnet )
36+ self .config ["scores" ] = serialize (self .scores )
37+
2838 def build (self , xz_shape : Shape , conditions_shape : Shape = None ) -> None :
2939 """Builds all network components based on shapes of conditions and targets.
3040
@@ -102,18 +112,15 @@ def build_from_config(self, config):
102112
103113 def get_config (self ):
104114 base_config = super ().get_config ()
105- base_config = model_kwargs (base_config )
106-
107- config = {
108- "scores" : self .scores ,
109- "subnet" : self .subnet ,
110- }
111115
112- return base_config | serialize ( config )
116+ return base_config | self . config
113117
114118 @classmethod
115- def from_config (cls , config , custom_objects = None ):
116- return cls (** deserialize (config , custom_objects = custom_objects ))
119+ def from_config (cls , config ):
120+ config = config .copy ()
121+ config ["scores" ] = deserialize (config ["scores" ])
122+ config = deserialize_value_or_type (config , "subnet" )
123+ return cls (** config )
117124
118125 def call (
119126 self ,
0 commit comments