Commit 63a4a4e

FusionNetwork: handle multiple networks with identical inputs
Up to now, the network strictly required passing a separate input to each summary network. If the same data should be used for all backbones, which might become a more common use case, this meant duplicating the data somewhere upstream. For this reason, I extended the fusion network with a second mode in which all backbones receive the same input. This is one possible implementation, but we might also outsource this functionality into a separate class.
1 parent 6914baf commit 63a4a4e
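
For context, here is a minimal usage sketch of the two modes described in this commit. The constructor and call signatures follow the diff below; the layer sizes, tensor shapes, and variable names are made up for illustration.

import keras
from bayesflow.networks import DeepSet, FusionNetwork

# Multi-modal mode: backbones and inputs are dictionaries with matching keys.
# Keys are sorted internally, so outputs are concatenated in a deterministic order.
multimodal_net = FusionNetwork(
    backbones={"x1": keras.layers.Dense(3), "x2": DeepSet(summary_dim=2, base_distribution="normal")},
    head=keras.layers.Dense(3),
)

# Identical-input mode (new in this commit): backbones are passed as a sequence,
# and every backbone receives the same tensor.
shared_net = FusionNetwork(
    backbones=[
        DeepSet(summary_dim=2, base_distribution="normal"),
        DeepSet(summary_dim=2, base_distribution="normal"),
    ],
    head=keras.layers.Dense(3),
)

x_vec = keras.ops.ones((8, 4))      # (batch, features), consumed by the Dense backbone
x_set = keras.ops.ones((8, 10, 4))  # (batch, set_size, features), consumed by the DeepSet backbones

summaries_multi = multimodal_net({"x1": x_vec, "x2": x_set})  # dict input, one entry per backbone
summaries_shared = shared_net(x_set)                          # a single tensor, fed to both backbones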

File tree

3 files changed: +143 -54 lines

bayesflow/networks/fusion_network/fusion_network.py

Lines changed: 83 additions & 28 deletions
@@ -1,4 +1,4 @@
-from collections.abc import Mapping
+from collections.abc import Mapping, Sequence
 from ..summary_network import SummaryNetwork
 from bayesflow.utils.serialization import deserialize, serializable, serialize
 from bayesflow.types import Tensor, Shape
@@ -10,23 +10,32 @@
 class FusionNetwork(SummaryNetwork):
     def __init__(
         self,
-        backbones: Mapping[str, keras.Layer],
+        backbones: Sequence | Mapping[str, keras.Layer],
         head: keras.Layer | None = None,
         **kwargs,
     ):
-        """(SN) Wraps multiple summary networks (`backbones`) to learn summary statistics from multi-modal data.
+        """(SN) Wraps multiple summary networks (`backbones`) to learn summary statistics from (optionally)
+        multi-modal data.
 
-        Networks and inputs are passed as dictionaries with corresponding keys, so that each input is processed
-        by the correct summary network. This means the "summary_variables" entry to the approximator has to be
-        a dictionary, which can be achieved using the :py:meth:`bayesflow.adapters.Adapter.group` method.
+        There are two modes of operation:
+
+        - Identical input: each backbone receives the same input. The backbones have to be passed as a sequence.
+        - Multi-modal input: each backbone gets its own input, which is the usual case for multi-modal data. Networks
+          and inputs have to be passed as dictionaries with corresponding keys, so that each
+          input is processed by the correct summary network. This means the "summary_variables" entry to the
+          approximator has to be a dictionary, which can be achieved using the
+          :py:meth:`bayesflow.adapters.Adapter.group` method.
 
         This network implements _late_ fusion. The output of the individual summary networks is concatenated, and
         can be further processed by another neural network (`head`).
 
         Parameters
         ----------
-        backbones : dict
-            A dictionary with names of inputs as keys and corresponding summary networks as values.
+        backbones : Sequence or dict
+            Either (see above for details):
+
+            - a sequence, when each backbone should receive the same input.
+            - a dictionary with names of inputs as keys and corresponding summary networks as values.
         head : keras.Layer, optional
             A network to further process the concatenated outputs of the summary networks. By default,
             the concatenated outputs are returned without further processing.
@@ -37,25 +46,51 @@ def __init__(
         super().__init__(**kwargs)
         self.backbones = backbones
         self.head = head
-        self._ordered_keys = sorted(list(self.backbones.keys()))
+        self._dict_mode = isinstance(backbones, Mapping)
+        if self._dict_mode:
+            # order keys to always concatenate in the same order
+            self._ordered_keys = sorted(list(self.backbones.keys()))
 
-    def build(self, inputs_shape: Mapping[str, Shape]):
+    def build(self, inputs_shape: Shape | Mapping[str, Shape]):
+        if self._dict_mode and not isinstance(inputs_shape, Mapping):
+            raise ValueError(
+                "`backbones` were passed as a dictionary, but the input shapes are not a dictionary. "
+                "If you want to pass the same input to each backbone, pass the backbones as a list instead of a "
+                "dictionary. If you want to provide each backbone with different input, please ensure that you have "
+                "correctly assembled the `summary_variables` to provide a dictionary using the Adapter.group method."
+            )
         if self.built:
             return
         output_shapes = []
-        for k, shape in inputs_shape.items():
-            if not self.backbones[k].built:
-                self.backbones[k].build(shape)
-            output_shapes.append(self.backbones[k].compute_output_shape(shape))
+        if self._dict_mode:
+            missing_keys = list(set(inputs_shape.keys()).difference(set(self._ordered_keys)))
+            if len(missing_keys) > 0:
+                raise ValueError(
+                    f"Expected the input to contain the following keys: {self._ordered_keys}. "
+                    f"Missing keys: {missing_keys}"
+                )
+            for k, shape in inputs_shape.items():
+                # build each summary network with different input shape
+                if not self.backbones[k].built:
+                    self.backbones[k].build(shape)
+                output_shapes.append(self.backbones[k].compute_output_shape(shape))
+        else:
+            for backbone in self.backbones:
+                # build all summary networks with the same input shape
+                if not backbone.built:
+                    backbone.build(inputs_shape)
+                output_shapes.append(backbone.compute_output_shape(inputs_shape))
         if self.head and not self.head.built:
             fusion_input_shape = (*output_shapes[0][:-1], sum(shape[-1] for shape in output_shapes))
             self.head.build(fusion_input_shape)
         self.built = True
 
     def compute_output_shape(self, inputs_shape: Mapping[str, Shape]):
         output_shapes = []
-        for k, shape in inputs_shape.items():
-            output_shapes.append(self.backbones[k].compute_output_shape(shape))
+        if self._dict_mode:
+            output_shapes = [self.backbones[k].compute_output_shape(shape) for k, shape in inputs_shape.items()]
+        else:
+            output_shapes = [backbone.compute_output_shape(inputs_shape) for backbone in self.backbones]
         output_shape = (*output_shapes[0][:-1], sum(shape[-1] for shape in output_shapes))
         if self.head:
             output_shape = self.head.compute_output_shape(output_shape)
@@ -65,13 +100,20 @@ def call(self, inputs: Mapping[str, Tensor], training=False):
         """
         Parameters
         ----------
-        inputs : dict[str, Tensor]
-            Each value in the dictionary is the input to the summary network with the corresponding key.
+        inputs : Tensor | dict[str, Tensor]
+            Either (see above for details):
+
+            - a tensor, when the backbones were passed as a list and should receive identical inputs
+            - a dictionary, when the backbones were passed as a dictionary, where each value is the input to the
+              summary network with the corresponding key.
         training : bool, optional
             Whether the model is in training mode, affecting layers like dropout and
             batch normalization. Default is False.
         """
-        outputs = [self.backbones[k](inputs[k], training=training) for k in self._ordered_keys]
+        if self._dict_mode:
+            outputs = [self.backbones[k](inputs[k], training=training) for k in self._ordered_keys]
+        else:
+            outputs = [backbone(inputs, training=training) for backbone in self.backbones]
         outputs = ops.concatenate(outputs, axis=-1)
         if self.head is None:
             return outputs
@@ -81,8 +123,12 @@ def compute_metrics(self, inputs: Mapping[str, Tensor], stage: str = "training",
         """
         Parameters
         ----------
-        inputs : dict[str, Tensor]
-            Each value in the dictionary is the input to the summary network with the corresponding key.
+        inputs : Tensor | dict[str, Tensor]
+            Either (see above for details):
+
+            - a tensor, when the backbones were passed as a list and should receive identical inputs
+            - a dictionary, when the backbones were passed as a dictionary, where each value is the input to the
+              summary network with the corresponding key.
         stage : bool, optional
             Whether the model is in training mode, affecting layers like dropout and
             batch normalization. Default is False.
@@ -93,14 +139,23 @@ def compute_metrics(self, inputs: Mapping[str, Tensor], stage: str = "training",
         self.build(keras.tree.map_structure(keras.ops.shape, inputs))
         metrics = {"loss": [], "outputs": []}
 
-        for k in self._ordered_keys:
-            if isinstance(self.backbones[k], SummaryNetwork):
-                metrics_k = self.backbones[k].compute_metrics(inputs[k], stage=stage, **kwargs)
-                metrics["outputs"].append(metrics_k["outputs"])
-                if "loss" in metrics_k:
-                    metrics["loss"].append(metrics_k["loss"])
+        def process_backbone(backbone, input):
+            # helper function to avoid code duplication for the two modes
+            if isinstance(backbone, SummaryNetwork):
+                backbone_metrics = backbone.compute_metrics(input, stage=stage, **kwargs)
+                metrics["outputs"].append(backbone_metrics["outputs"])
+                if "loss" in backbone_metrics:
+                    metrics["loss"].append(backbone_metrics["loss"])
             else:
-                metrics["outputs"].append(self.backbones[k](inputs[k], training=stage == "training"))
+                metrics["outputs"].append(backbone(input, training=stage == "training"))
+
+        if self._dict_mode:
+            for k in self._ordered_keys:
+                process_backbone(self.backbones[k], inputs[k])
+        else:
+            for backbone in self.backbones:
+                process_backbone(backbone, inputs)
+
         if len(metrics["loss"]) == 0:
             del metrics["loss"]
         else:
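
To make the shape handling above concrete, here is a small sketch of the late-fusion arithmetic and of the new dict-mode check in `build`. All sizes are hypothetical, and the error case mirrors `test_build_failure` in the test diff further down.

import keras
from bayesflow.networks import DeepSet, FusionNetwork

# Identical-input mode with two differently sized summary outputs.
net = FusionNetwork(
    backbones=[
        DeepSet(summary_dim=2, base_distribution="normal"),
        DeepSet(summary_dim=4, base_distribution="normal"),
    ],
    head=keras.layers.Dense(3),
)
net.build((8, 10, 5))  # a single shape is used to build every backbone
# The backbone outputs, of shape (8, 2) and (8, 4), are concatenated to (8, 6),
# which the Dense head then maps to (8, 3).

# Dict mode rejects a plain shape tuple and points the user to Adapter.group.
dict_net = FusionNetwork(backbones={"x1": keras.layers.Dense(3), "x2": keras.layers.Dense(3)})
try:
    dict_net.build((8, 10, 5))
except ValueError as err:
    print(err)
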
Lines changed: 27 additions & 4 deletions
@@ -1,17 +1,40 @@
 import pytest
 
 
+@pytest.fixture(params=[True, False])
+def multimodal(request):
+    return request.param
+
+
 @pytest.fixture()
-def multimodal_data(random_samples, random_set):
-    return {"x1": random_samples, "x2": random_set}
+def data(random_samples, random_set, multimodal):
+    if multimodal:
+        return {"x1": random_samples, "x2": random_set}
+    return random_set
 
 
 @pytest.fixture()
-def fusion_network():
+def fusion_network(multimodal):
     from bayesflow.networks import FusionNetwork, DeepSet
     import keras
 
+    deepset_kwargs = dict(
+        summary_dim=2,
+        mlp_widths_equivariant=(2, 2),
+        mlp_widths_invariant_inner=(2, 2),
+        mlp_widths_invariant_outer=(2, 2),
+        mlp_widths_invariant_last=(2, 2),
+        base_distribution="normal",
+    )
+    if multimodal:
+        return FusionNetwork(
+            backbones={"x1": keras.layers.Dense(3), "x2": DeepSet(**deepset_kwargs)},
+            head=keras.layers.Dense(3),
+        )
     return FusionNetwork(
-        backbones={"x1": keras.layers.Dense(3), "x2": DeepSet(summary_dim=2, base_distribution="normal")},
+        backbones=[
+            DeepSet(**deepset_kwargs),
+            DeepSet(**deepset_kwargs),
+        ],
         head=keras.layers.Dense(3),
     )

tests/test_networks/test_fusion_network/test_fusion_network.py

Lines changed: 33 additions & 22 deletions
@@ -6,52 +6,65 @@
 
 
 @pytest.mark.parametrize("automatic", [True, False])
-def test_build(automatic, fusion_network, multimodal_data):
+def test_build(automatic, fusion_network, data, multimodal):
     if fusion_network is None:
         pytest.skip(reason="Nothing to do, because there is no summary network.")
 
     assert fusion_network.built is False
 
     if automatic:
-        fusion_network(multimodal_data)
+        fusion_network(data)
     else:
-        fusion_network.build(keras.tree.map_structure(keras.ops.shape, multimodal_data))
+        fusion_network.build(keras.tree.map_structure(keras.ops.shape, data))
 
     assert fusion_network.built is True
 
     # check the model has variables
     assert fusion_network.variables, "Model has no variables."
 
 
+def test_build_failure(fusion_network, data, multimodal):
+    if not multimodal:
+        pytest.skip(reason="Nothing to do, as summary networks may consume arbitrary inputs")
+    with pytest.raises(ValueError):
+        fusion_network.build((3, 2, 2))
+    with pytest.raises(ValueError):
+        data["x3"] = data.pop("x1")
+        fusion_network.build(keras.tree.map_structure(keras.ops.shape, data))
+
+
 @pytest.mark.parametrize("automatic", [True, False])
-def test_build_functional_api(automatic, fusion_network, multimodal_data):
+def test_build_functional_api(automatic, fusion_network, data, multimodal):
     if fusion_network is None:
         pytest.skip(reason="Nothing to do, because there is no summary network.")
 
     assert fusion_network.built is False
 
-    inputs = {}
-    for k, v in multimodal_data.items():
-        inputs[k] = keras.layers.Input(shape=keras.ops.shape(v)[1:], name=k)
+    if multimodal:
+        inputs = {}
+        for k, v in data.items():
+            inputs[k] = keras.layers.Input(shape=keras.ops.shape(v)[1:], name=k)
+    else:
+        inputs = keras.layers.Input(shape=keras.ops.shape(data)[1:])
     outputs = fusion_network(inputs)
     model = keras.Model(inputs=inputs, outputs=outputs)
 
     if automatic:
-        model(multimodal_data)
+        model(data)
     else:
-        model.build(keras.tree.map_structure(keras.ops.shape, multimodal_data))
+        model.build(keras.tree.map_structure(keras.ops.shape, data))
 
     assert model.built is True
 
     # check the model has variables
     assert fusion_network.variables, "Model has no variables."
 
 
-def test_serialize_deserialize(fusion_network, multimodal_data):
+def test_serialize_deserialize(fusion_network, data, multimodal):
     if fusion_network is None:
         pytest.skip(reason="Nothing to do, because there is no summary network.")
 
-    fusion_network.build(keras.tree.map_structure(keras.ops.shape, multimodal_data))
+    fusion_network.build(keras.tree.map_structure(keras.ops.shape, data))
 
     serialized = serialize(fusion_network)
     deserialized = deserialize(serialized)
@@ -60,28 +73,28 @@ def test_serialize_deserialize(fusion_network, multimodal_data):
     assert keras.tree.lists_to_tuples(serialized) == keras.tree.lists_to_tuples(reserialized)
 
 
-def test_save_and_load(tmp_path, fusion_network, multimodal_data):
+def test_save_and_load(tmp_path, fusion_network, data, multimodal):
     if fusion_network is None:
         pytest.skip(reason="Nothing to do, because there is no summary network.")
 
-    fusion_network.build(keras.tree.map_structure(keras.ops.shape, multimodal_data))
+    fusion_network.build(keras.tree.map_structure(keras.ops.shape, data))
 
     keras.saving.save_model(fusion_network, tmp_path / "model.keras")
     loaded = keras.saving.load_model(tmp_path / "model.keras")
 
     assert_layers_equal(fusion_network, loaded)
-    assert allclose(fusion_network(multimodal_data), loaded(multimodal_data))
+    assert allclose(fusion_network(data), loaded(data))
 
 
 @pytest.mark.parametrize("stage", ["training", "validation"])
-def test_compute_metrics(stage, fusion_network, multimodal_data):
+def test_compute_metrics(stage, fusion_network, data, multimodal):
     if fusion_network is None:
         pytest.skip("Nothing to do, because there is no summary network.")
 
-    fusion_network.build(keras.tree.map_structure(keras.ops.shape, multimodal_data))
+    fusion_network.build(keras.tree.map_structure(keras.ops.shape, data))
 
-    metrics = fusion_network.compute_metrics(multimodal_data, stage=stage)
-    outputs_via_call = fusion_network(multimodal_data, training=stage == "training")
+    metrics = fusion_network.compute_metrics(data, stage=stage)
+    outputs_via_call = fusion_network(data, training=stage == "training")
 
     assert "outputs" in metrics
 
@@ -90,11 +103,9 @@ def test_compute_metrics(stage, fusion_network, multimodal_data):
     assert allclose(metrics["outputs"], outputs_via_call)
 
     # check that the batch dimension is preserved
-    assert (
-        keras.ops.shape(metrics["outputs"])[0]
-        == keras.ops.shape(multimodal_data[next(iter(multimodal_data.keys()))])[0]
-    )
+    batch_size = keras.ops.shape(data)[0] if not multimodal else keras.ops.shape(data[next(iter(data.keys()))])[0]
 
+    assert keras.ops.shape(metrics["outputs"])[0] == batch_size
     assert "loss" in metrics
     assert keras.ops.shape(metrics["loss"]) == ()
 
