Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 24 additions & 7 deletions bayesflow/networks/flow_matching/flow_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,19 @@

@serializable("bayesflow.networks")
class FlowMatching(InferenceNetwork):
"""(IN) Implements Optimal Transport Flow Matching, originally introduced as Rectified Flow, with ideas incorporated
from [1-3].

[1] Rectified Flow: arXiv:2209.03003
[2] Flow Matching: arXiv:2210.02747
[3] Optimal Transport Flow Matching: arXiv:2302.00482
"""(IN) Implements Optimal Transport Flow Matching, originally introduced as Rectified Flow, with ideas
incorporated from [1-5].

[1] Liu et al. (2022). Flow straight and fast: Learning to generate and transfer data with rectified flow.
arXiv preprint arXiv:2209.03003.
[2] Lipman et al. (2022). Flow matching for generative modeling.
arXiv preprint arXiv:2210.02747.
[3] Tong et al. (2023). Improving and generalizing flow-based generative models with minibatch optimal transport.
arXiv preprint arXiv:2302.00482.
[4] Wildberger et al. (2023). Flow matching for scalable simulation-based inference.
Advances in Neural Information Processing Systems, 36, 16837-16864.
[5] Orsini et al. (2025). Flow matching posterior estimation for simulation-based atmospheric retrieval of
exoplanets. IEEE Access.
"""

MLP_DEFAULT_CONFIG = {
Expand Down Expand Up @@ -59,6 +66,7 @@ def __init__(
integrate_kwargs: dict[str, any] = None,
optimal_transport_kwargs: dict[str, any] = None,
subnet_kwargs: dict[str, any] = None,
time_power_law_alpha: float = 0.0,
**kwargs,
):
"""
Expand Down Expand Up @@ -96,6 +104,9 @@ def __init__(
into a single vector or passed as separate arguments. If set to False, the subnet
must accept three separate inputs: 'x' (noisy parameters), 't' (time),
and optional 'conditions'. Default is True.
time_power_law_alpha: float, optional
Changes the distribution of sampled times during training. Time is sampled from a power law distribution
p(t) ∝ t^(1/(1+α)), where α is the provided value. Default is α=0, which corresponds to uniform sampling.
**kwargs
Additional keyword arguments passed to the subnet and other components.
"""
Expand All @@ -107,6 +118,9 @@ def __init__(
self.optimal_transport_kwargs = FlowMatching.OPTIMAL_TRANSPORT_DEFAULT_CONFIG | (optimal_transport_kwargs or {})

self.loss_fn = keras.losses.get(loss_fn)
self.time_power_law_alpha = float(time_power_law_alpha)
if self.time_power_law_alpha <= -1.0:
raise ValueError("'time_power_law_alpha' must be greater than -1.0.")

self.seed_generator = keras.random.SeedGenerator()

Expand Down Expand Up @@ -164,6 +178,7 @@ def get_config(self):
"integrate_kwargs": self.integrate_kwargs,
"optimal_transport_kwargs": self.optimal_transport_kwargs,
"concatenate_subnet_input": self._concatenate_subnet_input,
"time_power_law_alpha": self.time_power_law_alpha,
# we do not need to store subnet_kwargs
}

Expand Down Expand Up @@ -307,7 +322,9 @@ def compute_metrics(
# conditions must be resampled along with x1
conditions = keras.ops.take(conditions, assignments, axis=0)

t = keras.random.uniform((keras.ops.shape(x0)[0],), seed=self.seed_generator)
u = keras.random.uniform((keras.ops.shape(x0)[0],), seed=self.seed_generator)
# p(t) ∝ t^(1/(1+α)), the inverse CDF: F^(-1)(u) = u^(1+α), α=0 is uniform
t = u ** (1 + self.time_power_law_alpha)
t = expand_right_as(t, x0)

x = t * x1 + (1 - t) * x0
Expand Down