Merged
13 changes: 9 additions & 4 deletions bayesflow/approximators/continuous_approximator.py
@@ -65,15 +65,20 @@ def build_adapter(
         summary_variables : Sequence of str, optional
             Names of the summary variables in the data
         """
-        adapter = Adapter.create_default(inference_variables)
+        adapter = Adapter()
+        adapter.to_array()
+        adapter.convert_dtype("float64", "float32")
+        adapter.concatenate(inference_variables, into="inference_variables")
 
         if inference_conditions is not None:
-            adapter = adapter.concatenate(inference_conditions, into="inference_conditions")
+            adapter.concatenate(inference_conditions, into="inference_conditions")
 
         if summary_variables is not None:
-            adapter = adapter.as_set(summary_variables).concatenate(summary_variables, into="summary_variables")
+            adapter.as_set(summary_variables)
+            adapter.concatenate(summary_variables, into="summary_variables")
 
-        adapter = adapter.keep(["inference_variables", "inference_conditions", "summary_variables"]).standardize()
+        adapter.keep(["inference_variables", "inference_conditions", "summary_variables"])
+        adapter.standardize()
 
         return adapter
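The refactor drops Adapter.create_default and method chaining in favor of stepwise calls, which suggests each Adapter method now mutates the instance in place. A minimal usage sketch, assuming build_adapter is exposed at class level; the variable names "theta" and "x" are hypothetical:

from bayesflow.approximators import ContinuousApproximator

adapter = ContinuousApproximator.build_adapter(
    inference_variables=["theta"],
    inference_conditions=["x"],
)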

8 changes: 4 additions & 4 deletions bayesflow/experimental/continuous_time_consistency_model.py
@@ -9,7 +9,7 @@
 from bayesflow.types import Tensor
 from bayesflow.utils import (
     jvp,
-    concatenate,
+    concatenate_valid,
     find_network,
     keras_kwargs,
     expand_right_as,
@@ -201,7 +201,7 @@ def consistency_function(
         **kwargs : dict, optional, default: {}
             Additional keyword arguments passed to the inner network.
         """
-        xtc = concatenate(x / self.sigma_data, self.time_emb(t), conditions, axis=-1)
+        xtc = concatenate_valid([x / self.sigma_data, self.time_emb(t), conditions], axis=-1)
         f = self.subnet_projector(self.subnet(xtc, training=training, **kwargs))
         out = ops.cos(t) * x - ops.sin(t) * self.sigma_data * f
         return out
@@ -240,7 +240,7 @@ def compute_metrics(self, x: Tensor, conditions: Tensor = None, stage: str = "tr
         r = 1.0  # TODO: if consistency distillation training (not supported yet) is unstable, add schedule here
 
         def f_teacher(x, t):
-            o = self.subnet(concatenate(x, self.time_emb(t), conditions, axis=-1), training=stage == "training")
+            o = self.subnet(concatenate_valid([x, self.time_emb(t), conditions], axis=-1), training=stage == "training")
             return self.subnet_projector(o)
 
         primals = (xt / self.sigma_data, t)
@@ -254,7 +254,7 @@ def f_teacher(x, t):
         cos_sin_dFdt = ops.stop_gradient(cos_sin_dFdt)
 
         # calculate output of the network
-        xtc = concatenate(xt / self.sigma_data, self.time_emb(t), conditions, axis=-1)
+        xtc = concatenate_valid([xt / self.sigma_data, self.time_emb(t), conditions], axis=-1)
         student_out = self.subnet_projector(self.subnet(xtc, training=stage == "training"))
 
         # calculate the tangent
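These call sites now pass a list, and conditions may be None in the unconditional case; concatenate_valid presumably drops None entries before concatenating, which is what keeps that path working. A minimal sketch under that assumption:

import keras

def concatenate_valid(tensors, axis=-1):
    # hypothetical sketch: keep only the non-None entries, then concatenate
    tensors = [t for t in tensors if t is not None]
    return keras.ops.concatenate(tensors, axis=axis) if tensors else None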
6 changes: 3 additions & 3 deletions bayesflow/experimental/free_form_flow/free_form_flow.py
@@ -6,7 +6,7 @@
 from bayesflow.utils import (
     find_network,
     keras_kwargs,
-    concatenate,
+    concatenate_valid,
     jacobian,
     jvp,
     vjp,
@@ -181,7 +181,7 @@ def encode(self, x: Tensor, conditions: Tensor = None, training: bool = False, *
         if conditions is None:
             inp = x
         else:
-            inp = concatenate(x, conditions, axis=-1)
+            inp = concatenate_valid([x, conditions], axis=-1)
         network_out = self.encoder_projector(
             self.encoder_subnet(inp, training=training, **kwargs), training=training, **kwargs
         )
@@ -191,7 +191,7 @@ def decode(self, z: Tensor, conditions: Tensor = None, training: bool = False, *
         if conditions is None:
             inp = z
         else:
-            inp = concatenate(z, conditions, axis=-1)
+            inp = concatenate_valid([z, conditions], axis=-1)
         network_out = self.decoder_projector(
             self.decoder_subnet(inp, training=training, **kwargs), training=training, **kwargs
         )
60 changes: 26 additions & 34 deletions bayesflow/networks/flow_matching/flow_matching.py
@@ -45,21 +45,19 @@ class FlowMatching(InferenceNetwork):
     }
 
     INTEGRATE_DEFAULT_CONFIG = {
-        "method": "rk45",
-        "steps": "adaptive",
-        "tolerance": 1e-3,
-        "min_steps": 10,
-        "max_steps": 100,
+        "method": "euler",
+        "steps": 100,
     }
 
     def __init__(
         self,
         subnet: str | type = "mlp",
         base_distribution: str = "normal",
-        use_optimal_transport: bool = False,
+        use_optimal_transport: bool = True,
         loss_fn: str = "mse",
         integrate_kwargs: dict[str, any] = None,
         optimal_transport_kwargs: dict[str, any] = None,
+        subnet_kwargs: dict[str, any] = None,
         **kwargs,
     ):
         """
@@ -97,23 +95,17 @@ def __init__(
 
         self.use_optimal_transport = use_optimal_transport
 
-        new_integrate_kwargs = FlowMatching.INTEGRATE_DEFAULT_CONFIG.copy()
-        new_integrate_kwargs.update(integrate_kwargs or {})
-        self.integrate_kwargs = new_integrate_kwargs
-
-        new_optimal_transport_kwargs = FlowMatching.OPTIMAL_TRANSPORT_DEFAULT_CONFIG.copy()
-        new_optimal_transport_kwargs.update(optimal_transport_kwargs or {})
-        self.optimal_transport_kwargs = new_optimal_transport_kwargs
+        self.integrate_kwargs = FlowMatching.INTEGRATE_DEFAULT_CONFIG | (integrate_kwargs or {})
+        self.optimal_transport_kwargs = FlowMatching.OPTIMAL_TRANSPORT_DEFAULT_CONFIG | (optimal_transport_kwargs or {})
 
         self.loss_fn = keras.losses.get(loss_fn)
 
         self.seed_generator = keras.random.SeedGenerator()
 
+        subnet_kwargs = subnet_kwargs or {}
+
         if subnet == "mlp":
-            subnet_kwargs = FlowMatching.MLP_DEFAULT_CONFIG.copy()
-            subnet_kwargs.update(kwargs.get("subnet_kwargs", {}))
-        else:
-            subnet_kwargs = kwargs.get("subnet_kwargs", {})
+            subnet_kwargs = FlowMatching.MLP_DEFAULT_CONFIG | subnet_kwargs
 
         self.subnet = find_network(subnet, **subnet_kwargs)
         self.output_projector = keras.layers.Dense(units=None, bias_initializer="zeros")
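The | operator used above is Python 3.9+ dict union, where right-hand keys win on conflicts, so caller-supplied kwargs override the class defaults; a quick illustration with made-up values:

defaults = {"method": "euler", "steps": 100}
overrides = {"steps": 250}
merged = defaults | overrides  # {'method': 'euler', 'steps': 250}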
@@ -154,23 +146,23 @@ def from_config(cls, config):
         config = deserialize_value_or_type(config, "subnet")
         return cls(**config)
 
-    def velocity(self, xz: Tensor, t: float | Tensor, conditions: Tensor = None, training: bool = False) -> Tensor:
-        t = keras.ops.convert_to_tensor(t)
-        t = expand_right_as(t, xz)
-        t = keras.ops.broadcast_to(t, keras.ops.shape(xz)[:-1] + (1,))
+    def velocity(self, xz: Tensor, time: float | Tensor, conditions: Tensor = None, training: bool = False) -> Tensor:
+        time = keras.ops.convert_to_tensor(time, dtype=keras.ops.dtype(xz))
+        time = expand_right_as(time, xz)
+        time = keras.ops.broadcast_to(time, keras.ops.shape(xz)[:-1] + (1,))
 
         if conditions is None:
-            xtc = keras.ops.concatenate([xz, t], axis=-1)
+            xtc = keras.ops.concatenate([xz, time], axis=-1)
         else:
-            xtc = keras.ops.concatenate([xz, t, conditions], axis=-1)
+            xtc = keras.ops.concatenate([xz, time, conditions], axis=-1)
 
         return self.output_projector(self.subnet(xtc, training=training), training=training)
 
     def _velocity_trace(
-        self, xz: Tensor, t: Tensor, conditions: Tensor = None, max_steps: int = None, training: bool = False
+        self, xz: Tensor, time: Tensor, conditions: Tensor = None, max_steps: int = None, training: bool = False
     ) -> (Tensor, Tensor):
         def f(x):
-            return self.velocity(x, t, conditions=conditions, training=training)
+            return self.velocity(x, time=time, conditions=conditions, training=training)
 
         v, trace = jacobian_trace(f, xz, max_steps=max_steps, seed=self.seed_generator, return_output=True)
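Besides the rename from t to time, the conversion now pins the dtype to that of xz, avoiding float64/float32 mismatches when a Python float is passed in. The shape handling is roughly equivalent to this standalone sketch (the shapes are made up):

import keras

xz = keras.ops.ones((64, 2))
time = keras.ops.convert_to_tensor(0.5, dtype=keras.ops.dtype(xz))
time = keras.ops.reshape(time, (1, 1))  # roughly what expand_right_as does here
time = keras.ops.broadcast_to(time, keras.ops.shape(xz)[:-1] + (1,))  # shape (64, 1)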

@@ -181,8 +173,8 @@ def _forward(
     ) -> Tensor | tuple[Tensor, Tensor]:
         if density:
 
-            def deltas(t, xz):
-                v, trace = self._velocity_trace(xz, t, conditions=conditions, training=training)
+            def deltas(time, xz):
+                v, trace = self._velocity_trace(xz, time=time, conditions=conditions, training=training)
                 return {"xz": v, "trace": trace}
 
             state = {"xz": x, "trace": keras.ops.zeros(keras.ops.shape(x)[:-1] + (1,), dtype=keras.ops.dtype(x))}
@@ -193,8 +185,8 @@ def deltas(t, xz):
 
         return z, log_density
 
-        def deltas(t, xz):
-            return {"xz": self.velocity(xz, t, conditions=conditions, training=training)}
+        def deltas(time, xz):
+            return {"xz": self.velocity(xz, time=time, conditions=conditions, training=training)}
 
         state = {"xz": x}
         state = integrate(deltas, state, start_time=1.0, stop_time=0.0, **(self.integrate_kwargs | kwargs))
@@ -208,8 +200,8 @@ def _inverse(
     ) -> Tensor | tuple[Tensor, Tensor]:
         if density:
 
-            def deltas(t, xz):
-                v, trace = self._velocity_trace(xz, t, conditions=conditions, training=training)
+            def deltas(time, xz):
+                v, trace = self._velocity_trace(xz, time=time, conditions=conditions, training=training)
                 return {"xz": v, "trace": trace}
 
             state = {"xz": z, "trace": keras.ops.zeros(keras.ops.shape(z)[:-1] + (1,), dtype=keras.ops.dtype(z))}
@@ -220,8 +212,8 @@ def deltas(t, xz):
 
         return x, log_density
 
-        def deltas(t, xz):
-            return {"xz": self.velocity(xz, t, conditions=conditions, training=training)}
+        def deltas(time, xz):
+            return {"xz": self.velocity(xz, time=time, conditions=conditions, training=training)}
 
         state = {"xz": z}
         state = integrate(deltas, state, start_time=0.0, stop_time=1.0, **(self.integrate_kwargs | kwargs))
@@ -258,7 +250,7 @@ def compute_metrics(
 
         base_metrics = super().compute_metrics(x1, conditions, stage)
 
-        predicted_velocity = self.velocity(x, t, conditions, training=stage == "training")
+        predicted_velocity = self.velocity(x, time=t, conditions=conditions, training=stage == "training")
 
         loss = self.loss_fn(target_velocity, predicted_velocity)
         loss = keras.ops.mean(loss)
7 changes: 4 additions & 3 deletions bayesflow/utils/__init__.py
@@ -48,7 +48,7 @@
 )
 from .serialization import serialize_value_or_type, deserialize_value_or_type
 from .tensor_utils import (
-    concatenate,
+    concatenate_valid,
     expand,
     expand_as,
     expand_to,
@@ -59,12 +59,13 @@
     expand_right_as,
     expand_right_to,
     expand_tile,
+    pad,
+    searchsorted,
     size_of,
     stack_valid,
     tile_axis,
     tree_concatenate,
     tree_stack,
-    pad,
-    searchsorted,
 )
 from .validators import check_lengths_same
+from .workflow_utils import find_inference_network, find_summary_network
11 changes: 5 additions & 6 deletions bayesflow/utils/dict_utils.py
@@ -49,7 +49,7 @@ def convert_kwargs(f: Callable, *args: any, **kwargs: any) -> dict[str, any]:
     return parameters
 
 
-def filter_kwargs(kwargs: Mapping[str, T], f: Callable) -> Mapping[str, T]:
+def filter_kwargs(kwargs: dict[str, T], f: Callable) -> dict[str, T]:
     """Filter keyword arguments for f"""
     signature = inspect.signature(f)
 
@@ -63,11 +63,10 @@ def filter_kwargs(kwargs: Mapping[str, T], f: Callable) -> Mapping[str, T]:
     return kwargs
 
 
-def keras_kwargs(kwargs: Mapping[str, T]) -> dict[str, T]:
-    """Keep dictionary keys that do not end with _kwargs. Used for propagating
-    keyword arguments in nested layer classes.
-    """
-    return {key: value for key, value in kwargs.items() if not key.endswith("_kwargs")}
+def keras_kwargs(kwargs: dict[str, T]) -> dict[str, T]:
+    """Filter keyword arguments for keras.Layer"""
+    valid_keys = ["dtype", "name", "trainable"]
+    return {key: value for key, value in kwargs.items() if key in valid_keys}
 
 
 # TODO: rename and streamline and make protected
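The new filter keeps only the constructor arguments that keras.Layer itself accepts, instead of everything not ending in _kwargs; for example, with a hypothetical input dict:

kwargs = {"name": "flow", "trainable": True, "subnet_kwargs": {"depth": 3}}
keras_kwargs(kwargs)  # {'name': 'flow', 'trainable': True}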
6 changes: 4 additions & 2 deletions bayesflow/utils/integrate.py
@@ -3,6 +3,8 @@
 
 import keras
 
+from typing import Literal
+
 from bayesflow.types import Tensor
 from bayesflow.utils import filter_kwargs
 from . import logging
@@ -238,8 +240,8 @@ def integrate(
     stop_time: ArrayLike,
     min_steps: int = 10,
     max_steps: int = 10_000,
-    steps: int = "adaptive",
-    method: str = "rk45",
+    steps: int | Literal["adaptive"] = 100,
+    method: str = "euler",
     **kwargs,
 ) -> dict[str, ArrayLike]:
     match steps:
8 changes: 6 additions & 2 deletions bayesflow/utils/optimal_transport/sinkhorn.py
@@ -6,6 +6,7 @@
 from ..dispatch import find_cost
 from .. import logging
 from ..numpy_utils import softmax
+from ..tensor_utils import is_symbolic_tensor
 
 
 def sinkhorn(
@@ -134,8 +135,8 @@ def sinkhorn_indices(
     rng = np.random.default_rng(seed)
 
     indices = []
-    for row in range(cost.shape[0]):
-        index = rng.choice(cost.shape[1], p=plan[row])
+    for row in range(plan.shape[0]):
+        index = rng.choice(plan.shape[1], p=plan[row])
         indices.append(index)
 
     indices = np.array(indices)
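Indexing the plan rather than the cost matrix ties the loop bounds to the array actually being sampled from; each row of the plan serves as a categorical distribution over column indices, as in this toy numpy example:

import numpy as np

rng = np.random.default_rng(0)
plan = np.full((3, 3), 1.0 / 3.0)  # toy doubly stochastic plan
indices = np.array([rng.choice(plan.shape[1], p=plan[row]) for row in range(plan.shape[0])])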
@@ -190,6 +191,9 @@ def sinkhorn_plan_keras(cost: Tensor, regularization: float, max_steps: int, tol
     # initialize the transport plan from a gaussian kernel
     plan = keras.ops.exp(-0.5 * cost / regularization)
 
+    if is_symbolic_tensor(plan):
+        return plan
+
     def is_converged(plan):
         # check convergence: the plan should be doubly stochastic
         marginals = keras.ops.sum(plan, axis=0), keras.ops.sum(plan, axis=1)
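The early return presumably exists because the convergence check below is data-dependent and cannot execute while Keras traces the graph with symbolic placeholder inputs; a plausible sketch of the helper, under that assumption:

import keras

def is_symbolic_tensor(x) -> bool:
    # KerasTensors are shape/dtype placeholders produced during functional
    # tracing; they carry no values, so the iterative refinement is skipped
    return isinstance(x, keras.KerasTensor)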