48 changes: 48 additions & 0 deletions bayesflow/adapters/adapter.py
@@ -14,6 +14,7 @@
Drop,
ExpandDims,
FilterTransform,
Group,
Keep,
Log,
MapTransform,
@@ -25,6 +26,7 @@
Standardize,
ToArray,
Transform,
Ungroup,
RandomSubsample,
Take,
)
@@ -600,6 +602,52 @@ def expand_dims(self, keys: str | Sequence[str], *, axis: int | tuple):
self.transforms.append(transform)
return self

def group(self, keys: Sequence[str], into: str, *, prefix: str = ""):
"""Append a :py:class:`~transforms.Group` transform to the adapter.

        Groups the given variables into a dictionary stored under the key `into`. As most transforms
        do not support nested structures, this should usually be the last transform in the adapter.

Parameters
----------
keys : Sequence of str
The names of the variables to group together.
into : str
The name of the variable to store the grouped variables in.
prefix : str, optional
An optional common prefix of the variable names before grouping, which will be removed after grouping.

Raises
------
ValueError
If a prefix is specified, but a provided key does not start with the prefix.
"""
if isinstance(keys, str):
keys = [keys]

transform = Group(keys=keys, into=into, prefix=prefix)
self.transforms.append(transform)
return self

def ungroup(self, key: str, *, prefix: str = ""):
"""Append an :py:class:`~transforms.Ungroup` transform to the adapter.

        Ungroups the variables in `key` from a dictionary into individual entries. Most transforms do
        not support nested structures, so this can be used to flatten a nested structure.
        The nesting can be re-established after the transforms using the :py:meth:`group` method.

Parameters
----------
key : str
The name of the variable to ungroup. The corresponding variable has to be a dictionary.
prefix : str, optional
An optional common prefix that will be added to the ungrouped variable names. This can be necessary
to avoid duplicate names.
"""
transform = Ungroup(key=key, prefix=prefix)
self.transforms.append(transform)
return self

def keep(self, keys: str | Sequence[str]):
"""Append a :py:class:`~transforms.Keep` transform to the adapter.

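A hypothetical usage sketch of the two new methods (the variable names and surrounding transforms are illustrative, not part of this PR): ungroup a nested entry so per-variable transforms see flat keys, then re-group as the final step.

```python
import numpy as np
from bayesflow.adapters import Adapter

adapter = (
    Adapter()
    # flatten the nested dict under "summaries" into "s_x1", "s_x2"
    .ungroup("summaries", prefix="s_")
    # per-variable transforms now operate on the flat keys
    .log("s_x1")
    # re-establish the nesting as the last transform; the prefix is stripped again
    .group(["s_x1", "s_x2"], into="summary_variables", prefix="s_")
)

data = {"summaries": {"x1": np.ones((2, 3)), "x2": np.zeros((2, 3))}}
processed = adapter(data)
# {"summary_variables": {"x1": ..., "x2": ...}}
```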
2 changes: 2 additions & 0 deletions bayesflow/adapters/transforms/__init__.py
@@ -8,6 +8,7 @@
from .elementwise_transform import ElementwiseTransform
from .expand_dims import ExpandDims
from .filter_transform import FilterTransform
from .group import Group
from .keep import Keep
from .log import Log
from .map_transform import MapTransform
@@ -25,6 +26,7 @@
from .transform import Transform
from .random_subsample import RandomSubsample
from .take import Take
from .ungroup import Ungroup

from ...utils._docs import _add_imports_to_all

81 changes: 81 additions & 0 deletions bayesflow/adapters/transforms/group.py
@@ -0,0 +1,81 @@
from collections.abc import Sequence
from .transform import Transform
from bayesflow.utils.serialization import serializable, serialize


@serializable("bayesflow.adapters")
class Group(Transform):
def __init__(self, keys: Sequence[str], into: str, prefix: str = ""):
        Groups the given variables into a dictionary stored under the key `into`. As most transforms
        do not support nested structures, this should usually be the last transform.

Parameters
----------
keys : Sequence of str
The names of the variables to group together.
into : str
The name of the variable to store the grouped variables in.
prefix : str, optional
A common prefix of the ungrouped variable names, which will be removed after grouping.

Raises
------
ValueError
If a prefix is specified, but a provided key does not start with the prefix.
"""
super().__init__()
self.keys = keys
self.into = into
self.prefix = prefix
for key in keys:
if not key.startswith(prefix):
                raise ValueError(f"If a prefix is specified, all keys must start with it. Found '{key}'.")

def get_config(self) -> dict:
return serialize({"keys": self.keys, "into": self.into, "prefix": self.prefix})

def forward(self, data: dict[str, any], *, strict: bool = True, **kwargs) -> dict[str, any]:
data = data.copy()

data[self.into] = data.get(self.into, {})
for key in self.keys:
if key not in data:
if strict:
raise KeyError(f"Missing key: {key!r}")
else:
data[self.into][key[len(self.prefix) :]] = data.pop(key)

return data

def inverse(self, data: dict[str, any], *, strict: bool = False, **kwargs) -> dict[str, any]:
data = data.copy()

if strict and self.into not in data:
raise KeyError(f"Missing key: {self.into!r}")
elif self.into not in data:
return data

for key in self.keys:
internal_key = key[len(self.prefix) :]
if internal_key not in data[self.into]:
if strict:
raise KeyError(f"Missing key: {internal_key!r}")
else:
data[key] = data[self.into].pop(internal_key)

if len(data[self.into]) == 0:
del data[self.into]

return data

def extra_repr(self) -> str:
return f"{self.keys!r} -> {self.into!r}"

def log_det_jac(
self,
data: dict[str, any],
log_det_jac: dict[str, any],
inverse: bool = False,
**kwargs,
):
return self.inverse(data=log_det_jac) if inverse else self.forward(data=log_det_jac, strict=False)
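A minimal round-trip sketch of `Group` used standalone, outside an adapter (the sample keys and values are illustrative):

```python
from bayesflow.adapters.transforms import Group

t = Group(keys=["p_a", "p_b"], into="params", prefix="p_")

grouped = t.forward({"p_a": 1, "p_b": 2})
# {"params": {"a": 1, "b": 2}} -- the prefix is stripped inside the group

restored = t.inverse(grouped)
# {"p_a": 1, "p_b": 2} -- the prefix is re-added and the emptied group is removed
```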
83 changes: 83 additions & 0 deletions bayesflow/adapters/transforms/ungroup.py
@@ -0,0 +1,83 @@
from .transform import Transform
from bayesflow.utils.serialization import deserialize, serializable, serialize


@serializable("bayesflow.adapters")
class Ungroup(Transform):
def __init__(self, key: str, prefix: str = ""):
"""
        Ungroups the variables in `key` from a dictionary into individual entries. Most transforms do
        not support nested structures, so this can be used to flatten a nested structure.
        The nesting can later be reassembled using the :py:class:`bayesflow.adapters.transforms.Group` transform.

Parameters
----------
key : str
The name of the variable to ungroup. The variable has to be a dictionary.
prefix : str, optional
An optional common prefix that will be added to the ungrouped variable names. This can be necessary
to avoid duplicate names.
"""
super().__init__()
self.key = key
self.prefix = prefix
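        # records the keys created by forward, so inverse can restore the nesting later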
self._ungrouped = None

def get_config(self) -> dict:
return serialize({"key": self.key, "prefix": self.prefix, "_ungrouped": self._ungrouped})

@classmethod
def from_config(cls, config: dict, custom_objects=None):
config = deserialize(config, custom_objects)
_ungrouped = config.pop("_ungrouped")
transform = cls(**config)
transform._ungrouped = _ungrouped
return transform

def forward(self, data: dict[str, any], *, strict: bool = True, **kwargs) -> dict[str, any]:
data = data.copy()

if self.key not in data and strict:
raise KeyError(f"Missing key: {self.key!r}")
elif self.key not in data:
return data

ungrouped = []
for k, v in data.pop(self.key).items():
new_key = f"{self.prefix}{k}"
if new_key in data:
raise ValueError(
f"Encountered duplicate key during ungrouping: '{new_key}'."
" Use `prefix` to specify a unique prefix that is added to the key"
)
ungrouped.append(new_key)
data[new_key] = v
if self._ungrouped is None:
self._ungrouped = sorted(ungrouped)
else:
self._ungrouped = sorted(list(set(self._ungrouped + ungrouped)))

return data

def inverse(self, data: dict[str, any], *, strict: bool = False, **kwargs) -> dict[str, any]:
data = data.copy()

data[self.key] = {}
for key in self._ungrouped:
if key not in data:
if strict:
raise KeyError(f"Missing key: {key!r}")
else:
recovered_key = key[len(self.prefix) :]
data[self.key][recovered_key] = data.pop(key)

return data

def log_det_jac(
self,
data: dict[str, any],
log_det_jac: dict[str, any],
inverse: bool = False,
**kwargs,
):
return self.inverse(data=log_det_jac) if inverse else self.forward(data=log_det_jac, strict=False)
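A minimal sketch of `Ungroup` used standalone (keys are illustrative), showing how the prefix avoids collisions with existing flat keys:

```python
from bayesflow.adapters.transforms import Ungroup

t = Ungroup(key="summaries", prefix="s_")

flat = t.forward({"summaries": {"x": 1}, "x": 0})
# {"s_x": 1, "x": 0} -- without the prefix, the existing "x" would trigger the ValueError

nested = t.inverse(flat)
# {"summaries": {"x": 1}, "x": 0}
```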
1 change: 1 addition & 0 deletions bayesflow/networks/__init__.py
@@ -11,6 +11,7 @@
from .inference_network import InferenceNetwork
from .point_inference_network import PointInferenceNetwork
from .mlp import MLP
from .fusion_network import FusionNetwork
from .summary_network import SummaryNetwork
from .time_series_network import TimeSeriesNetwork
from .transformers import SetTransformer, TimeSeriesTransformer, FusionTransformer
1 change: 1 addition & 0 deletions bayesflow/networks/fusion_network/__init__.py
@@ -0,0 +1 @@
from .fusion_network import FusionNetwork
123 changes: 123 additions & 0 deletions bayesflow/networks/fusion_network/fusion_network.py
@@ -0,0 +1,123 @@
from collections.abc import Mapping
from ..summary_network import SummaryNetwork
from bayesflow.utils.serialization import deserialize, serializable, serialize
from bayesflow.types import Tensor, Shape
import keras
from keras import ops


@serializable("bayesflow.networks")
class FusionNetwork(SummaryNetwork):
def __init__(
self,
backbones: Mapping[str, keras.Layer],
head: keras.Layer | None = None,
**kwargs,
):
"""(SN) Wraps multiple summary networks (`backbones`) to learn summary statistics from multi-modal data.

        Networks and inputs are passed as dictionaries with matching keys, so that each input is processed
        by the correct summary network. This means the "summary_variables" entry passed to the approximator
        has to be a dictionary, which can be achieved using the :py:meth:`bayesflow.adapters.Adapter.group` method.

        This network implements _late_ fusion: the outputs of the individual summary networks are concatenated
        and can be further processed by another neural network (`head`).

Parameters
----------
backbones : dict
A dictionary with names of inputs as keys and corresponding summary networks as values.
head : keras.Layer, optional
A network to further process the concatenated outputs of the summary networks. By default,
the concatenated outputs are returned without further processing.
**kwargs
Additional keyword arguments that are passed to the :py:class:`~bayesflow.networks.SummaryNetwork`
base class.
"""
super().__init__(**kwargs)
self.backbones = backbones
self.head = head
self._ordered_keys = sorted(list(self.backbones.keys()))

def build(self, inputs_shape: Mapping[str, Shape]):
if self.built:
return
output_shapes = []
for k, shape in inputs_shape.items():
if not self.backbones[k].built:
self.backbones[k].build(shape)
output_shapes.append(self.backbones[k].compute_output_shape(shape))
if self.head and not self.head.built:
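            # late fusion: backbone outputs are concatenated along the last axis, so the
            # fused feature dimension is the sum of the backbones' output dimensions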
fusion_input_shape = (*output_shapes[0][:-1], sum(shape[-1] for shape in output_shapes))
self.head.build(fusion_input_shape)
self.built = True

def compute_output_shape(self, inputs_shape: Mapping[str, Shape]):
output_shapes = []
for k, shape in inputs_shape.items():
output_shapes.append(self.backbones[k].compute_output_shape(shape))
output_shape = (*output_shapes[0][:-1], sum(shape[-1] for shape in output_shapes))
if self.head:
output_shape = self.head.compute_output_shape(output_shape)
return output_shape

def call(self, inputs: Mapping[str, Tensor], training=False):
"""
Parameters
----------
inputs : dict[str, Tensor]
Each value in the dictionary is the input to the summary network with the corresponding key.
training : bool, optional
Whether the model is in training mode, affecting layers like dropout and
batch normalization. Default is False.
"""
outputs = [self.backbones[k](inputs[k], training=training) for k in self._ordered_keys]
outputs = ops.concatenate(outputs, axis=-1)
if self.head is None:
return outputs
return self.head(outputs, training=training)

def compute_metrics(self, inputs: Mapping[str, Tensor], stage: str = "training", **kwargs) -> dict[str, Tensor]:
"""
Parameters
----------
inputs : dict[str, Tensor]
Each value in the dictionary is the input to the summary network with the corresponding key.
        stage : str, optional
            The current stage, e.g. "training", "validation", or "inference". Affects layers
            like dropout and batch normalization. Default is "training".
**kwargs
Additional keyword arguments.
"""
metrics = {"loss": [], "outputs": []}

for k in self._ordered_keys:
if isinstance(self.backbones[k], SummaryNetwork):
metrics_k = self.backbones[k].compute_metrics(inputs[k], stage=stage, **kwargs)
metrics["outputs"].append(metrics_k["outputs"])
if "loss" in metrics_k:
metrics["loss"].append(metrics_k["loss"])
else:
metrics["outputs"].append(self.backbones[k](inputs[k], training=stage == "training"))
if len(metrics["loss"]) == 0:
del metrics["loss"]
else:
metrics["loss"] = ops.sum(metrics["loss"])
metrics["outputs"] = ops.concatenate(metrics["outputs"], axis=-1)
if self.head is not None:
metrics["outputs"] = self.head(metrics["outputs"], training=stage == "training")

return metrics

def get_config(self) -> dict:
base_config = super().get_config()
config = {
"backbones": self.backbones,
"head": self.head,
}
return base_config | serialize(config)

@classmethod
def from_config(cls, config: dict, custom_objects=None):
config = deserialize(config, custom_objects=custom_objects)
return cls(**config)
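A hedged usage sketch pairing the adapter's `group` with `FusionNetwork` (the modality names, backbone choices, and head size are illustrative, not prescribed by this PR):

```python
import keras
import bayesflow as bf

# group the two flat modalities into the nested "summary_variables" dict
adapter = (
    bf.Adapter()
    .group(["time_series", "measurements"], into="summary_variables")
)

# one backbone per modality, keyed like the grouped inputs; late fusion
# concatenates their outputs and the optional head compresses them further
summary_network = bf.networks.FusionNetwork(
    backbones={
        "time_series": bf.networks.TimeSeriesNetwork(),
        "measurements": bf.networks.DeepSet(),
    },
    head=keras.layers.Dense(16),
)
```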