@@ -33,6 +33,9 @@ class ModelComparisonApproximator(Approximator):
3333 summary_network: bf.networks.SummaryNetwork, optional
3434 The summary network used for data summarization (default is None).
3535 The input of the summary network is `summary_variables`.
36+ logits_projector: keras.layers.Layer, optional
37+ A layer that projects the output of the classifier network to the logits space for each model.
38+ If not provided, a dense layer with `num_models` units is used by default.
3639 """
3740
3841 SAMPLE_KEYS = ["summary_variables" , "classifier_conditions" ]
@@ -44,14 +47,15 @@ def __init__(
4447 classifier_network : keras .Layer ,
4548 adapter : Adapter ,
4649 summary_network : SummaryNetwork = None ,
50+ logits_projector : keras .layers .Layer = None ,
4751 ** kwargs ,
4852 ):
4953 super ().__init__ (** kwargs )
5054 self .classifier_network = classifier_network
5155 self .adapter = adapter
5256 self .summary_network = summary_network
5357 self .num_models = num_models
54- self .logits_projector = keras .layers .Dense (num_models )
58+ self .logits_projector = logits_projector if logits_projector is not None else keras .layers .Dense (num_models )
5559
5660 def build (self , data_shapes : Mapping [str , Shape ]):
5761 data = {key : keras .ops .zeros (value ) for key , value in data_shapes .items ()}
@@ -266,6 +270,7 @@ def get_config(self):
266270 "adapter" : self .adapter ,
267271 "classifier_network" : self .classifier_network ,
268272 "summary_network" : self .summary_network ,
273+ "logits_projector" : self .logits_projector ,
269274 }
270275
271276 return base_config | serialize (config )
0 commit comments