diff --git a/bayesflow/networks/flow_matching/flow_matching.py b/bayesflow/networks/flow_matching/flow_matching.py index da4acd321..59deab53b 100644 --- a/bayesflow/networks/flow_matching/flow_matching.py +++ b/bayesflow/networks/flow_matching/flow_matching.py @@ -20,12 +20,19 @@ @serializable("bayesflow.networks") class FlowMatching(InferenceNetwork): - """(IN) Implements Optimal Transport Flow Matching, originally introduced as Rectified Flow, with ideas incorporated - from [1-3]. - - [1] Rectified Flow: arXiv:2209.03003 - [2] Flow Matching: arXiv:2210.02747 - [3] Optimal Transport Flow Matching: arXiv:2302.00482 + """(IN) Implements Optimal Transport Flow Matching, originally introduced as Rectified Flow, with ideas + incorporated from [1-5]. + + [1] Liu et al. (2022). Flow straight and fast: Learning to generate and transfer data with rectified flow. + arXiv preprint arXiv:2209.03003. + [2] Lipman et al. (2022). Flow matching for generative modeling. + arXiv preprint arXiv:2210.02747. + [3] Tong et al. (2023). Improving and generalizing flow-based generative models with minibatch optimal transport. + arXiv preprint arXiv:2302.00482. + [4] Wildberger et al. (2023). Flow matching for scalable simulation-based inference. + Advances in Neural Information Processing Systems, 36, 16837-16864. + [5] Orsini et al. (2025). Flow matching posterior estimation for simulation-based atmospheric retrieval of + exoplanets. IEEE Access. """ MLP_DEFAULT_CONFIG = { @@ -59,6 +66,7 @@ def __init__( integrate_kwargs: dict[str, any] = None, optimal_transport_kwargs: dict[str, any] = None, subnet_kwargs: dict[str, any] = None, + time_power_law_alpha: float = 0.0, **kwargs, ): """ @@ -96,6 +104,9 @@ def __init__( into a single vector or passed as separate arguments. If set to False, the subnet must accept three separate inputs: 'x' (noisy parameters), 't' (time), and optional 'conditions'. Default is True. 
+ time_power_law_alpha: float, optional + Changes the distribution of sampled times during training. Time is sampled from a power law distribution + p(t) ∝ t^(-α/(1+α)), where α is the provided value. Default is α=0, which corresponds to uniform sampling. **kwargs Additional keyword arguments passed to the subnet and other components. """ @@ -107,6 +118,9 @@ def __init__( self.optimal_transport_kwargs = FlowMatching.OPTIMAL_TRANSPORT_DEFAULT_CONFIG | (optimal_transport_kwargs or {}) self.loss_fn = keras.losses.get(loss_fn) + self.time_power_law_alpha = float(time_power_law_alpha) + if self.time_power_law_alpha <= -1.0: + raise ValueError("'time_power_law_alpha' must be greater than -1.0.") self.seed_generator = keras.random.SeedGenerator() @@ -164,6 +178,7 @@ def get_config(self): "integrate_kwargs": self.integrate_kwargs, "optimal_transport_kwargs": self.optimal_transport_kwargs, "concatenate_subnet_input": self._concatenate_subnet_input, + "time_power_law_alpha": self.time_power_law_alpha, # we do not need to store subnet_kwargs } @@ -307,7 +322,9 @@ def compute_metrics( # conditions must be resampled along with x1 conditions = keras.ops.take(conditions, assignments, axis=0) - t = keras.random.uniform((keras.ops.shape(x0)[0],), seed=self.seed_generator) + u = keras.random.uniform((keras.ops.shape(x0)[0],), seed=self.seed_generator) + # p(t) ∝ t^(-α/(1+α)), the inverse CDF: F^(-1)(u) = u^(1+α), α=0 is uniform + t = u ** (1 + self.time_power_law_alpha) t = expand_right_as(t, x0) x = t * x1 + (1 - t) * x0