Skip to content

Commit bff8d20

Browse files
committed
Remove comments wrt serialization proposal and fix score name change
1 parent 94720f2 commit bff8d20

File tree

1 file changed

+3
-35
lines changed

1 file changed

+3
-35
lines changed

bayesflow/networks/point_inference_network.py

Lines changed: 3 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from bayesflow.utils import keras_kwargs, find_network, serialize_value_or_type, deserialize_value_or_type
99
from bayesflow.types import Shape, Tensor
10-
from bayesflow.scores import ScoringRule, ParametricDistributionRule
10+
from bayesflow.scores import ScoringRule, ParametricDistributionScore
1111
from bayesflow.utils.decorators import allow_batch_size
1212

1313

@@ -98,9 +98,6 @@ def get_build_config(self):
9898
heads[score_key] = {}
9999
for head_key, head in self.heads[score_key].items():
100100
heads[score_key][head_key] = head.name
101-
# Alternatively, save full build config of head
102-
# heads[score_key][head_key] = head.get_build_config()
103-
# TODO: decide
104101

105102
build_config["heads"] = heads
106103

@@ -113,35 +110,6 @@ def build_from_config(self, config):
113110
for head_key, head in self.heads[score_key].items():
114111
head.name = config["heads"][score_key][head_key]
115112

116-
# Alternatively, do NOT call self.build, but rather imitate it using the build config of each head
117-
# This results in some code duplication with self.build and requires heads to be of a custom type.
118-
# TODO: decide
119-
120-
# input_shape = config["conditions_shape"]
121-
#
122-
# # Save input_shape for usage in get_build_config
123-
# self._input_shape = input_shape
124-
#
125-
# # build the shared body network
126-
# self.subnet.build(input_shape)
127-
#
128-
# # build head(s) for every scoring rule
129-
# self.heads = dict()
130-
# self.heads_flat = dict()
131-
#
132-
# for score_key in self.scores.keys():
133-
#
134-
# self.heads[score_key] = {}
135-
#
136-
# for head_key, head_config in config["heads"][score_key].items():
137-
# head = keras.Sequential()
138-
# head.build_from_config(head_config) # TODO: this method is not implemented yet
139-
# it would require the head to be a
140-
# custom object rather than a Sequential
141-
# self.heads[score_key][head_key] = head
142-
# flat_key = f"{score_key}___{head_key}"
143-
# self.heads_flat[flat_key] = head
144-
145113
def get_config(self):
146114
base_config = super().get_config()
147115

@@ -202,7 +170,7 @@ def sample(self, batch_shape: Shape, conditions: Tensor = None, **kwargs) -> dic
202170
samples = {}
203171

204172
for score_key, score in self.scores.items():
205-
if isinstance(score, ParametricDistributionRule):
173+
if isinstance(score, ParametricDistributionScore):
206174
parameters = {head_key: head(output) for head_key, head in self.heads[score_key].items()}
207175
samples[score_key] = score.sample(batch_shape, **parameters)
208176

@@ -214,7 +182,7 @@ def log_prob(self, samples: Tensor, conditions: Tensor = None, **kwargs) -> dict
214182
log_probs = {}
215183

216184
for score_key, score in self.scores.items():
217-
if isinstance(score, ParametricDistributionRule):
185+
if isinstance(score, ParametricDistributionScore):
218186
parameters = {head_key: head(output) for head_key, head in self.heads[score_key].items()}
219187
log_probs[score_key] = score.log_prob(x=samples, **parameters)
220188

0 commit comments

Comments
 (0)