Skip to content

Commit 635071e

Browse files
committed
Modify model comparison approximator
1 parent 9c88fb9 commit 635071e

File tree

1 file changed

+16
-15
lines changed

1 file changed

+16
-15
lines changed

bayesflow/approximators/model_comparison_approximator.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,16 @@ class ModelComparisonApproximator(Approximator):
2525
2626
Parameters
2727
----------
28-
adapter: Adapter
29-
Adapter for data processing.
28+
adapter: bf.adapters.Adapter
29+
Adapter for data pre-processing.
3030
num_models: int
3131
Number of models (simulators) that the approximator will compare
32-
classifier_network: keras.Model
33-
The network (e.g, an MLP) that is used for model classification.
32+
classifier_network: keras.Layer
33+
The network backbone (e.g, an MLP) that is used for model classification.
3434
The input of the classifier network is created by concatenating `classifier_variables`
3535
and (optional) output of the summary_network.
36-
summary_network: SummaryNetwork, optional
37-
The summary network used for data summarisation (default is None).
36+
summary_network: bg.networks.SummaryNetwork, optional
37+
The summary network used for data summarization (default is None).
3838
The input of the summary network is `summary_variables`.
3939
"""
4040

@@ -51,7 +51,7 @@ def __init__(
5151
self.classifier_network = classifier_network
5252
self.adapter = adapter
5353
self.summary_network = summary_network
54-
54+
self.num_models = num_models
5555
self.logits_projector = keras.layers.Dense(num_models)
5656

5757
def build(self, data_shapes: Mapping[str, Shape]):
@@ -61,6 +61,7 @@ def build(self, data_shapes: Mapping[str, Shape]):
6161
@classmethod
6262
def build_adapter(
6363
cls,
64+
num_models: int,
6465
classifier_conditions: Sequence[str] = None,
6566
summary_variables: Sequence[str] = None,
6667
model_index_name: str = "model_indices",
@@ -80,11 +81,9 @@ def build_adapter(
8081
adapter.rename(model_index_name, "model_indices")
8182
.keep(["classifier_conditions", "summary_variables", "model_indices"])
8283
.standardize(exclude="model_indices")
84+
.one_hot("model_indices", num_models)
8385
)
8486

85-
# TODO: add one-hot encoding
86-
# .one_hot("model_indices", self.num_models)
87-
8887
return adapter
8988

9089
@classmethod
@@ -239,7 +238,7 @@ def fit(
239238

240239
if adapter == "auto":
241240
logging.info("Building automatic data adapter.")
242-
adapter = self.build_adapter(**filter_kwargs(kwargs, self.build_adapter))
241+
adapter = self.build_adapter(num_models=self.num_models, **filter_kwargs(kwargs, self.build_adapter))
243242

244243
if simulator is not None:
245244
return super().fit(simulator=simulator, adapter=adapter, **kwargs)
@@ -252,15 +251,17 @@ def fit(
252251

253252
@classmethod
254253
def from_config(cls, config, custom_objects=None):
255-
adapter = deserialize(config["adapter"], custom_objects=custom_objects)
256-
classifier_network = deserialize(config["classifier_network"], custom_objects=custom_objects)
257-
summary_network = deserialize(config["summary_network"], custom_objects=custom_objects)
258-
return cls(adapter=adapter, classifier_network=classifier_network, summary_network=summary_network, **config)
254+
config["num_models"] = deserialize(config["num_models"], custom_objects=custom_objects)
255+
config["adapter"] = deserialize(config["adapter"], custom_objects=custom_objects)
256+
(config["classifier_network"],) = deserialize(config["classifier_network"], custom_objects=custom_objects)
257+
config["summary_network"] = deserialize(config["summary_network"], custom_objects=custom_objects)
258+
return super().from_config(config, custom_objects=custom_objects)
259259

260260
def get_config(self):
261261
base_config = super().get_config()
262262

263263
config = {
264+
"num_models": serialize(self.num_models),
264265
"adapter": serialize(self.adapter),
265266
"classifier_network": serialize(self.classifier_network),
266267
"summary_network": serialize(self.summary_network),

0 commit comments

Comments
 (0)