
 from bayesflow.utils import keras_kwargs, find_network, serialize_value_or_type, deserialize_value_or_type
 from bayesflow.types import Shape, Tensor
-from bayesflow.scores import ScoringRule, ParametricDistributionRule
+from bayesflow.scores import ScoringRule, ParametricDistributionScore
 from bayesflow.utils.decorators import allow_batch_size


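The renamed import matters later in this diff: `sample` and `log_prob` only act on scores that are `ParametricDistributionScore` instances, while plain `ScoringRule` heads are skipped. A minimal sketch of that distinction (the helper function and its name are hypothetical, not part of the library):

    from bayesflow.scores import ScoringRule, ParametricDistributionScore

    def supports_sampling(score: ScoringRule) -> bool:
        # Mirrors the isinstance check used in sample() and log_prob() below:
        # only parametric distribution scores define a predictive distribution
        # that can be sampled from or evaluated as a log density.
        return isinstance(score, ParametricDistributionScore)
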
@@ -98,9 +98,6 @@ def get_build_config(self):
             heads[score_key] = {}
             for head_key, head in self.heads[score_key].items():
                 heads[score_key][head_key] = head.name
-                # Alternatively, save full build config of head
-                # heads[score_key][head_key] = head.get_build_config()
-                # TODO: decide

         build_config["heads"] = heads

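After this change, `get_build_config` stores only each head's name rather than a full per-head build config. A hedged sketch of the resulting structure, assuming a hypothetical, already-built network `net`:

    # `net` stands for a built point inference network (hypothetical variable).
    build_config = net.get_build_config()

    # build_config["heads"] is a nested mapping: score_key -> head_key -> head name.
    for score_key, heads in build_config["heads"].items():
        for head_key, head_name in heads.items():
            print(score_key, head_key, head_name)
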
@@ -113,35 +110,6 @@ def build_from_config(self, config):
             for head_key, head in self.heads[score_key].items():
                 head.name = config["heads"][score_key][head_key]

-        # Alternatively, do NOT call self.build, but rather imitate it using the build config of each head
-        # This results in some code duplication with self.build and requires heads to be of a custom type.
-        # TODO: decide
-
-        # input_shape = config["conditions_shape"]
-        #
-        # # Save input_shape for usage in get_build_config
-        # self._input_shape = input_shape
-        #
-        # # build the shared body network
-        # self.subnet.build(input_shape)
-        #
-        # # build head(s) for every scoring rule
-        # self.heads = dict()
-        # self.heads_flat = dict()
-        #
-        # for score_key in self.scores.keys():
-        #
-        #     self.heads[score_key] = {}
-        #
-        #     for head_key, head_config in config["heads"][score_key].items():
-        #         head = keras.Sequential()
-        #         head.build_from_config(head_config)  # TODO: this method is not implemented yet
-        #                                                it would require the head to be a
-        #                                                custom object rather than a Sequential
-        #         self.heads[score_key][head_key] = head
-        #         flat_key = f"{score_key}___{head_key}"
-        #         self.heads_flat[flat_key] = head
-

     def get_config(self):
         base_config = super().get_config()

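With the commented-out alternative dropped, deserialization keeps the simpler path: `build_from_config` rebuilds the heads through the regular build and then only restores the saved head names. A rough sketch of the round trip, assuming a built network `net` and a second network `restored` constructed with the same scores and subnet settings (both variables are hypothetical):

    # Save side: the build config records the conditions shape and head names.
    config = net.get_build_config()

    # Load side: rebuild the heads via the normal build path, then reassign
    # the stored head names from config["heads"].
    restored.build_from_config(config)
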
@@ -202,7 +170,7 @@ def sample(self, batch_shape: Shape, conditions: Tensor = None, **kwargs) -> dict
         samples = {}

         for score_key, score in self.scores.items():
-            if isinstance(score, ParametricDistributionRule):
+            if isinstance(score, ParametricDistributionScore):
                 parameters = {head_key: head(output) for head_key, head in self.heads[score_key].items()}
                 samples[score_key] = score.sample(batch_shape, **parameters)

@@ -214,7 +182,7 @@ def log_prob(self, samples: Tensor, conditions: Tensor = None, **kwargs) -> dict
         log_probs = {}

         for score_key, score in self.scores.items():
-            if isinstance(score, ParametricDistributionRule):
+            if isinstance(score, ParametricDistributionScore):
                 parameters = {head_key: head(output) for head_key, head in self.heads[score_key].items()}
                 log_probs[score_key] = score.log_prob(x=samples, **parameters)

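Both methods return dictionaries keyed by score, and as the isinstance checks above show, entries are produced only for parametric distribution scores; other scoring rules are silently skipped. A hypothetical usage sketch, assuming a built network `net`, a conditions tensor `conditions`, and candidate values `some_values` (all three names are placeholders):

    # Draw 100 samples per condition; the result maps each parametric score's
    # key to a tensor of draws.
    samples = net.sample((100,), conditions=conditions)

    # Evaluate the log density of given values under the same predictive
    # distributions; again a dict keyed by score.
    log_probs = net.log_prob(samples=some_values, conditions=conditions)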