Skip to content

Commit e2bd0e2

Browse files
committed
Fix #506 (missing logits_projector serialization)
1 parent 735969c commit e2bd0e2

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

bayesflow/approximators/model_comparison_approximator.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@ class ModelComparisonApproximator(Approximator):
3333
summary_network: bf.networks.SummaryNetwork, optional
3434
The summary network used for data summarization (default is None).
3535
The input of the summary network is `summary_variables`.
36+
logits_projector: keras.layers.Layer, optional
37+
A layer that projects the output of the classifier network to the logits space for each model.
38+
If not provided, a dense layer with `num_models` units is used by default.
3639
"""
3740

3841
SAMPLE_KEYS = ["summary_variables", "classifier_conditions"]
@@ -44,14 +47,15 @@ def __init__(
4447
classifier_network: keras.Layer,
4548
adapter: Adapter,
4649
summary_network: SummaryNetwork = None,
50+
logits_projector: keras.layers.Layer = None,
4751
**kwargs,
4852
):
4953
super().__init__(**kwargs)
5054
self.classifier_network = classifier_network
5155
self.adapter = adapter
5256
self.summary_network = summary_network
5357
self.num_models = num_models
54-
self.logits_projector = keras.layers.Dense(num_models)
58+
self.logits_projector = logits_projector if logits_projector is not None else keras.layers.Dense(num_models)
5559

5660
def build(self, data_shapes: Mapping[str, Shape]):
5761
data = {key: keras.ops.zeros(value) for key, value in data_shapes.items()}
@@ -266,6 +270,7 @@ def get_config(self):
266270
"adapter": self.adapter,
267271
"classifier_network": self.classifier_network,
268272
"summary_network": self.summary_network,
273+
"logits_projector": self.logits_projector,
269274
}
270275

271276
return base_config | serialize(config)

0 commit comments

Comments
 (0)