@@ -25,16 +25,16 @@ class ModelComparisonApproximator(Approximator):
2525
2626 Parameters
2727 ----------
28- adapter: Adapter
29- Adapter for data processing.
28+ adapter: bf.adapters. Adapter
29+ Adapter for data pre- processing.
3030 num_models: int
3131 Number of models (simulators) that the approximator will compare
32- classifier_network: keras.Model
33- The network (e.g, an MLP) that is used for model classification.
32+ classifier_network: keras.Layer
33+ The network backbone (e.g, an MLP) that is used for model classification.
3434 The input of the classifier network is created by concatenating `classifier_variables`
3535 and (optional) output of the summary_network.
36- summary_network: SummaryNetwork, optional
37- The summary network used for data summarisation (default is None).
36+ summary_network: bg.networks. SummaryNetwork, optional
37+ The summary network used for data summarization (default is None).
3838 The input of the summary network is `summary_variables`.
3939 """
4040
@@ -51,7 +51,7 @@ def __init__(
5151 self .classifier_network = classifier_network
5252 self .adapter = adapter
5353 self .summary_network = summary_network
54-
54+ self . num_models = num_models
5555 self .logits_projector = keras .layers .Dense (num_models )
5656
5757 def build (self , data_shapes : Mapping [str , Shape ]):
@@ -61,6 +61,7 @@ def build(self, data_shapes: Mapping[str, Shape]):
6161 @classmethod
6262 def build_adapter (
6363 cls ,
64+ num_models : int ,
6465 classifier_conditions : Sequence [str ] = None ,
6566 summary_variables : Sequence [str ] = None ,
6667 model_index_name : str = "model_indices" ,
@@ -80,11 +81,9 @@ def build_adapter(
8081 adapter .rename (model_index_name , "model_indices" )
8182 .keep (["classifier_conditions" , "summary_variables" , "model_indices" ])
8283 .standardize (exclude = "model_indices" )
84+ .one_hot ("model_indices" , num_models )
8385 )
8486
85- # TODO: add one-hot encoding
86- # .one_hot("model_indices", self.num_models)
87-
8887 return adapter
8988
9089 @classmethod
@@ -239,7 +238,7 @@ def fit(
239238
240239 if adapter == "auto" :
241240 logging .info ("Building automatic data adapter." )
242- adapter = self .build_adapter (** filter_kwargs (kwargs , self .build_adapter ))
241+ adapter = self .build_adapter (num_models = self . num_models , ** filter_kwargs (kwargs , self .build_adapter ))
243242
244243 if simulator is not None :
245244 return super ().fit (simulator = simulator , adapter = adapter , ** kwargs )
@@ -252,15 +251,17 @@ def fit(
252251
253252 @classmethod
254253 def from_config (cls , config , custom_objects = None ):
255- adapter = deserialize (config ["adapter" ], custom_objects = custom_objects )
256- classifier_network = deserialize (config ["classifier_network" ], custom_objects = custom_objects )
257- summary_network = deserialize (config ["summary_network" ], custom_objects = custom_objects )
258- return cls (adapter = adapter , classifier_network = classifier_network , summary_network = summary_network , ** config )
254+ config ["num_models" ] = deserialize (config ["num_models" ], custom_objects = custom_objects )
255+ config ["adapter" ] = deserialize (config ["adapter" ], custom_objects = custom_objects )
256+ (config ["classifier_network" ],) = deserialize (config ["classifier_network" ], custom_objects = custom_objects )
257+ config ["summary_network" ] = deserialize (config ["summary_network" ], custom_objects = custom_objects )
258+ return super ().from_config (config , custom_objects = custom_objects )
259259
260260 def get_config (self ):
261261 base_config = super ().get_config ()
262262
263263 config = {
264+ "num_models" : serialize (self .num_models ),
264265 "adapter" : serialize (self .adapter ),
265266 "classifier_network" : serialize (self .classifier_network ),
266267 "summary_network" : serialize (self .summary_network ),
0 commit comments