@@ -5,15 +5,7 @@

from bayesflow.networks import MLP
from bayesflow.types import Tensor
from bayesflow.utils import (
logging,
jvp,
concatenate_valid,
find_network,
expand_right_as,
expand_right_to,
layer_kwargs,
)
from bayesflow.utils import logging, jvp, find_network, expand_right_as, expand_right_to, layer_kwargs, tensor_utils
from bayesflow.utils.serialization import deserialize, serializable, serialize

from bayesflow.networks import InferenceNetwork
@@ -83,6 +75,11 @@ def __init__(
includes depth, hidden sizes, and non-linearity choices.
embedding_kwargs : dict[str, any], optional, default=None
Keyword arguments for the time embedding layer(s) used in the model
concatenate_subnet_input : bool, optional, default=True
Flag for advanced users to control whether all inputs to the subnet are concatenated
into a single vector or passed as separate arguments. If set to False, the subnet
must accept three separate inputs: 'x' (noisy parameters), 't' (time embedding),
and optional 'conditions'.
**kwargs
Additional keyword arguments passed to the parent ``InferenceNetwork`` initializer
(e.g., ``name``, ``dtype``, or ``trainable``).
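To make the `concatenate_subnet_input=False` contract above concrete, here is a minimal sketch of a custom subnet that is built and called with separate `x`, `t`, and `conditions` tensors; the class name, layer widths, and activation are illustrative and not part of this PR:

import keras

class SeparateInputSubnet(keras.layers.Layer):
    # Toy subnet for the non-concatenated interface: built and called with
    # separate 'x', 't', and optional 'conditions' tensors.
    def __init__(self, width: int = 64, **kwargs):
        super().__init__(**kwargs)
        self.width = width
        self.dense_x = keras.layers.Dense(width)
        self.dense_t = keras.layers.Dense(width)
        self.dense_c = keras.layers.Dense(width)

    def build(self, x_shape, t_shape, conditions_shape=None):
        self.dense_x.build(x_shape)
        self.dense_t.build(t_shape)
        if conditions_shape is not None:
            self.dense_c.build(conditions_shape)
        self.built = True

    def compute_output_shape(self, x_shape, t_shape, conditions_shape=None):
        return (*x_shape[:-1], self.width)

    def call(self, x, t, conditions=None, training=False):
        h = self.dense_x(x) + self.dense_t(t)
        if conditions is not None:
            h = h + self.dense_c(conditions)
        return keras.ops.gelu(h)

# Hypothetical usage:
# model = StableConsistencyModel(subnet=SeparateInputSubnet(), concatenate_subnet_input=False)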
@@ -97,6 +94,7 @@
self.subnet_projector = keras.layers.Dense(
units=None, bias_initializer="zeros", kernel_initializer="zeros", name="subnet_projector"
)
self._concatenate_subnet_input = kwargs.get("concatenate_subnet_input", True)

weight_mlp_kwargs = weight_mlp_kwargs or {}
weight_mlp_kwargs = StableConsistencyModel.WEIGHT_MLP_DEFAULT_CONFIG | weight_mlp_kwargs
@@ -107,6 +105,7 @@
)

embedding_kwargs = embedding_kwargs or {}
self.embedding_kwargs = embedding_kwargs
self.time_emb = FourierEmbedding(**embedding_kwargs)
self.time_emb_dim = self.time_emb.embed_dim

@@ -124,6 +123,8 @@ def get_config(self):
config = {
"subnet": self.subnet,
"sigma": self.sigma,
"embedding_kwargs": self.embedding_kwargs,
"concatenate_subnet_input": self._concatenate_subnet_input,
}

return base_config | serialize(config)
@@ -151,17 +152,22 @@ def build(self, xz_shape, conditions_shape=None):
# construct input shape for subnet and subnet projector
input_shape = list(xz_shape)

# time vector
input_shape[-1] += self.time_emb_dim + 1

if conditions_shape is not None:
input_shape[-1] += conditions_shape[-1]

input_shape = tuple(input_shape)

self.subnet.build(input_shape)

input_shape = self.subnet.compute_output_shape(input_shape)
if self._concatenate_subnet_input:
# construct time vector
input_shape[-1] += self.time_emb_dim + 1
if conditions_shape is not None:
input_shape[-1] += conditions_shape[-1]
input_shape = tuple(input_shape)

self.subnet.build(input_shape)
input_shape = self.subnet.compute_output_shape(input_shape)
else:
# Multiple separate inputs
time_shape = tuple(xz_shape[:-1]) + (self.time_emb_dim + 1,)  # same batch/sequence dims, time_emb_dim + 1 embedding features
self.subnet.build(x_shape=xz_shape, t_shape=time_shape, conditions_shape=conditions_shape)
input_shape = self.subnet.compute_output_shape(
x_shape=xz_shape, t_shape=time_shape, conditions_shape=conditions_shape
)
self.subnet_projector.build(input_shape)

# input shape for time embedding
@@ -173,6 +179,35 @@ def build(self, xz_shape, conditions_shape=None):
input_shape = self.weight_fn.compute_output_shape(input_shape)
self.weight_fn_projector.build(input_shape)

def _apply_subnet(
self, x: Tensor, t: Tensor, conditions: Tensor = None, training: bool = False
) -> Tensor:
"""
Prepares the input for the subnet, either by concatenating the latent variable `x`,
the time embedding `t`, and optional conditions into a single tensor, or by passing
them to the subnet as separate arguments.

Parameters
----------
x : Tensor
The parameter tensor, typically of shape (..., D), but can vary.
t : Tensor
The time embedding tensor, typically of shape (..., time_emb_dim + 1).
conditions : Tensor, optional
The optional conditioning tensor (e.g., summary statistics of the observed data).
training : bool, optional
The training mode flag, which can be used to control behavior during training.

Returns
-------
Tensor
The output tensor from the subnet.
"""
if self._concatenate_subnet_input:
xtc = tensor_utils.concatenate_valid([x, t, conditions], axis=-1)
return self.subnet(xtc, training=training)
else:
return self.subnet(x=x, t=t, conditions=conditions, training=training)

def _forward(self, x: Tensor, conditions: Tensor = None, **kwargs) -> Tensor:
# Consistency Models only learn the direction from noise distribution
# to target distribution, so we cannot implement this function.
@@ -218,7 +253,6 @@ def consistency_function(
t: Tensor,
conditions: Tensor = None,
training: bool = False,
**kwargs,
) -> Tensor:
"""Compute consistency function at time t.

@@ -235,8 +269,8 @@
"""
xtc = concatenate_valid([x / self.sigma, self.time_emb(t), conditions], axis=-1)
f = self.subnet_projector(self.subnet(xtc, training=training, **kwargs))
subnet_out = self._apply_subnet(x / self.sigma, self.time_emb(t), conditions, training=training)
f = self.subnet_projector(subnet_out)
out = ops.cos(t) * x - ops.sin(t) * self.sigma * f
return out

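(For reference, the two lines above implement the trigonometric consistency parameterization used by continuous-time "stable" consistency models: f_theta(x, t) = cos(t) * x - sin(t) * sigma * F_theta(x / sigma, emb(t), conditions), where F_theta is the projected subnet output and emb(t) is the Fourier time embedding.)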
@@ -273,7 +307,7 @@ def compute_metrics(
r = 1.0 # TODO: if consistency distillation training (not supported yet) is unstable, add schedule here

def f_teacher(x, t):
o = self.subnet(concatenate_valid([x, self.time_emb(t), conditions], axis=-1), training=stage == "training")
o = self._apply_subnet(x, self.time_emb(t), conditions, training=stage == "training")
return self.subnet_projector(o)

primals = (xt / self.sigma, t)
@@ -287,8 +321,8 @@ def f_teacher(x, t):
cos_sin_dFdt = ops.stop_gradient(cos_sin_dFdt)

# calculate output of the network
xtc = concatenate_valid([xt / self.sigma, self.time_emb(t), conditions], axis=-1)
student_out = self.subnet_projector(self.subnet(xtc, training=stage == "training"))
subnet_out = self._apply_subnet(xt / self.sigma, self.time_emb(t), conditions, training=stage == "training")
student_out = self.subnet_projector(subnet_out)

# calculate the tangent
g = -(ops.cos(t) ** 2) * (self.sigma * teacher_output - dxtdt) - r * ops.cos(t) * ops.sin(t) * (
bayesflow/networks/embeddings/fourier_embedding.py (22 changes: 21 additions & 1 deletion)
@@ -4,7 +4,8 @@
from keras import ops

from bayesflow.types import Tensor
from bayesflow.utils.serialization import serializable
from bayesflow.utils import layer_kwargs
from bayesflow.utils.serialization import serializable, serialize, deserialize


@serializable("bayesflow.networks")
@@ -47,6 +48,8 @@ def __init__(
self.scale = scale
self.embed_dim = embed_dim
self.include_identity = include_identity
self.initializer = initializer
self.trainable = trainable

def call(self, t: Tensor) -> Tensor:
"""Embeds the one-dimensional time scalar into a higher-dimensional Fourier embedding.
@@ -68,3 +71,20 @@ def call(self, t: Tensor) -> Tensor:
else:
emb = ops.concatenate([ops.sin(proj), ops.cos(proj)], axis=-1)
return emb

def get_config(self):
base_config = super().get_config()
base_config = layer_kwargs(base_config)

config = {
"embed_dim": self.embed_dim,
"scale": self.scale,
"initializer": self.initializer,
"trainable": self.trainable,
"include_identity": self.include_identity,
}
return base_config | serialize(config)

@classmethod
def from_config(cls, config, custom_objects=None):
return cls(**deserialize(config, custom_objects=custom_objects))
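A minimal round-trip check of the serialization added above (the import path and the embed_dim value are assumptions for illustration, not part of this PR):

from bayesflow.networks.embeddings import FourierEmbedding  # assumed import path

emb = FourierEmbedding(embed_dim=8)
config = emb.get_config()  # now includes embed_dim, scale, initializer, trainable, include_identity
restored = FourierEmbedding.from_config(config)
assert restored.embed_dim == emb.embed_dim and restored.scale == emb.scale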