@@ -1,4 +1,4 @@
-from collections.abc import Mapping
+from collections.abc import Mapping, Sequence
 from ..summary_network import SummaryNetwork
 from bayesflow.utils.serialization import deserialize, serializable, serialize
 from bayesflow.types import Tensor, Shape
@@ -10,23 +10,32 @@
 class FusionNetwork(SummaryNetwork):
     def __init__(
         self,
-        backbones: Mapping[str, keras.Layer],
+        backbones: Sequence | Mapping[str, keras.Layer],
         head: keras.Layer | None = None,
         **kwargs,
     ):
-        """(SN) Wraps multiple summary networks (`backbones`) to learn summary statistics from multi-modal data.
+        """(SN) Wraps multiple summary networks (`backbones`) to learn summary statistics from (optionally)
+        multi-modal data.

-        Networks and inputs are passed as dictionaries with corresponding keys, so that each input is processed
-        by the correct summary network. This means the "summary_variables" entry to the approximator has to be
-        a dictionary, which can be achieved using the :py:meth:`bayesflow.adapters.Adapter.group` method.
+        There are two modes of operation:
+
+        - Identical input: each backbone receives the same input. The backbones have to be passed as a sequence.
+        - Multi-modal input: each backbone gets its own input, which is the usual case for multi-modal data. Networks
+          and inputs have to be passed as dictionaries with corresponding keys, so that each
+          input is processed by the correct summary network. This means the "summary_variables" entry to the
+          approximator has to be a dictionary, which can be achieved using the
+          :py:meth:`bayesflow.adapters.Adapter.group` method.

         This network implements _late_ fusion. The output of the individual summary networks is concatenated, and
         can be further processed by another neural network (`head`).

         Parameters
         ----------
-        backbones : dict
-            A dictionary with names of inputs as keys and corresponding summary networks as values.
+        backbones : Sequence or dict
+            Either (see above for details):
+
+            - a sequence, when each backbone should receive the same input.
+            - a dictionary with names of inputs as keys and corresponding summary networks as values.
         head : keras.Layer, optional
             A network to further process the concatenated outputs of the summary networks. By default,
             the concatenated outputs are returned without further processing.
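For orientation (not part of the diff), a minimal construction sketch of the two modes. It assumes `FusionNetwork` and a `DeepSet` summary network are importable from `bayesflow.networks`; the keys, head, and dimensions are purely illustrative:

import keras
from bayesflow.networks import DeepSet, FusionNetwork

# Multi-modal mode: the backbone keys must match the keys of the
# "summary_variables" dictionary assembled by the adapter (Adapter.group).
multimodal_net = FusionNetwork(
    backbones={"image_data": DeepSet(), "sensor_data": DeepSet()},
    head=keras.layers.Dense(32),
)

# Identical-input mode: backbones are passed as a sequence and every backbone
# receives the same input tensor.
shared_input_net = FusionNetwork(backbones=[DeepSet(), DeepSet()])

In the multi-modal case the dictionary keys are the contract between the adapter output and the backbones; internally the keys are sorted so that the concatenation order is deterministic.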
@@ -37,25 +46,51 @@ def __init__(
         super().__init__(**kwargs)
         self.backbones = backbones
         self.head = head
-        self._ordered_keys = sorted(list(self.backbones.keys()))
+        self._dict_mode = isinstance(backbones, Mapping)
+        if self._dict_mode:
+            # order keys to always concatenate in the same order
+            self._ordered_keys = sorted(list(self.backbones.keys()))

-    def build(self, inputs_shape: Mapping[str, Shape]):
+    def build(self, inputs_shape: Shape | Mapping[str, Shape]):
+        if self._dict_mode and not isinstance(inputs_shape, Mapping):
+            raise ValueError(
+                "`backbones` were passed as a dictionary, but the input shapes are not a dictionary. "
+                "If you want to pass the same input to each backbone, pass the backbones as a list instead of a "
+                "dictionary. If you want to provide each backbone with different input, please ensure that you have "
+                "correctly assembled the `summary_variables` to provide a dictionary using the Adapter.group method."
+            )
         if self.built:
             return
         output_shapes = []
-        for k, shape in inputs_shape.items():
-            if not self.backbones[k].built:
-                self.backbones[k].build(shape)
-            output_shapes.append(self.backbones[k].compute_output_shape(shape))
+        if self._dict_mode:
+            missing_keys = list(set(self._ordered_keys).difference(set(inputs_shape.keys())))
+            if len(missing_keys) > 0:
+                raise ValueError(
+                    f"Expected the input to contain the following keys: {self._ordered_keys}. "
+                    f"Missing keys: {missing_keys}"
+                )
+            for k, shape in inputs_shape.items():
+                # build each summary network with different input shape
+                if not self.backbones[k].built:
+                    self.backbones[k].build(shape)
+                output_shapes.append(self.backbones[k].compute_output_shape(shape))
+        else:
+            for backbone in self.backbones:
+                # build all summary networks with the same input shape
+                if not backbone.built:
+                    backbone.build(inputs_shape)
+                output_shapes.append(backbone.compute_output_shape(inputs_shape))
         if self.head and not self.head.built:
             fusion_input_shape = (*output_shapes[0][:-1], sum(shape[-1] for shape in output_shapes))
             self.head.build(fusion_input_shape)
         self.built = True

     def compute_output_shape(self, inputs_shape: Mapping[str, Shape]):
         output_shapes = []
-        for k, shape in inputs_shape.items():
-            output_shapes.append(self.backbones[k].compute_output_shape(shape))
+        if self._dict_mode:
+            output_shapes = [self.backbones[k].compute_output_shape(shape) for k, shape in inputs_shape.items()]
+        else:
+            output_shapes = [backbone.compute_output_shape(inputs_shape) for backbone in self.backbones]
         output_shape = (*output_shapes[0][:-1], sum(shape[-1] for shape in output_shapes))
         if self.head:
             output_shape = self.head.compute_output_shape(output_shape)
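A quick worked example of the fused shape used here and in `build` (values made up): only the last dimensions are summed, the leading dimensions are taken from the first backbone output.

output_shapes = [(8, 4), (8, 6)]  # hypothetical backbone output shapes
fused = (*output_shapes[0][:-1], sum(shape[-1] for shape in output_shapes))
assert fused == (8, 10)  # the head (if any) is built on this shape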
@@ -65,13 +100,20 @@ def call(self, inputs: Mapping[str, Tensor], training=False):
         """
         Parameters
         ----------
-        inputs : dict[str, Tensor]
-            Each value in the dictionary is the input to the summary network with the corresponding key.
+        inputs : Tensor | dict[str, Tensor]
+            Either (see above for details):
+
+            - a tensor, when the backbones were passed as a list and should receive identical inputs.
+            - a dictionary, when the backbones were passed as a dictionary, where each value is the input to the
+              summary network with the corresponding key.
         training : bool, optional
             Whether the model is in training mode, affecting layers like dropout and
             batch normalization. Default is False.
         """
-        outputs = [self.backbones[k](inputs[k], training=training) for k in self._ordered_keys]
+        if self._dict_mode:
+            outputs = [self.backbones[k](inputs[k], training=training) for k in self._ordered_keys]
+        else:
+            outputs = [backbone(inputs, training=training) for backbone in self.backbones]
         outputs = ops.concatenate(outputs, axis=-1)
         if self.head is None:
             return outputs
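A hypothetical dict-mode call, reusing `multimodal_net` from the earlier sketch (keys and shapes are again illustrative): each entry is routed to the backbone stored under the same key, and the outputs are concatenated along the last axis before the optional head.

import keras

data = {
    "image_data": keras.random.normal((8, 5, 16)),   # (batch, set size, features)
    "sensor_data": keras.random.normal((8, 12, 3)),
}
summaries = multimodal_net(data)  # (8, 32) given the Dense(32) head in the sketch above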
@@ -81,8 +123,12 @@ def compute_metrics(self, inputs: Mapping[str, Tensor], stage: str = "training",
         """
         Parameters
         ----------
-        inputs : dict[str, Tensor]
-            Each value in the dictionary is the input to the summary network with the corresponding key.
+        inputs : Tensor | dict[str, Tensor]
+            Either (see above for details):
+
+            - a tensor, when the backbones were passed as a list and should receive identical inputs.
+            - a dictionary, when the backbones were passed as a dictionary, where each value is the input to the
+              summary network with the corresponding key.
         stage : str, optional
             The current stage, e.g. "training", which affects layers like dropout and
             batch normalization. Default is "training".
@@ -93,14 +139,23 @@ def compute_metrics(self, inputs: Mapping[str, Tensor], stage: str = "training",
             self.build(keras.tree.map_structure(keras.ops.shape, inputs))
         metrics = {"loss": [], "outputs": []}

-        for k in self._ordered_keys:
-            if isinstance(self.backbones[k], SummaryNetwork):
-                metrics_k = self.backbones[k].compute_metrics(inputs[k], stage=stage, **kwargs)
-                metrics["outputs"].append(metrics_k["outputs"])
-                if "loss" in metrics_k:
-                    metrics["loss"].append(metrics_k["loss"])
+        def process_backbone(backbone, input):
+            # helper function to avoid code duplication for the two modes
+            if isinstance(backbone, SummaryNetwork):
+                backbone_metrics = backbone.compute_metrics(input, stage=stage, **kwargs)
+                metrics["outputs"].append(backbone_metrics["outputs"])
+                if "loss" in backbone_metrics:
+                    metrics["loss"].append(backbone_metrics["loss"])
             else:
-                metrics["outputs"].append(self.backbones[k](inputs[k], training=stage == "training"))
+                metrics["outputs"].append(backbone(input, training=stage == "training"))
+
+        if self._dict_mode:
+            for k in self._ordered_keys:
+                process_backbone(self.backbones[k], inputs[k])
+        else:
+            for backbone in self.backbones:
+                process_backbone(backbone, inputs)
+
         if len(metrics["loss"]) == 0:
             del metrics["loss"]
         else: