Discussion: Adding SciPy integrators #551

@vpratz

Our integration suite currently lacks an adaptive method that lets users specify error tolerances. Control over the error is especially important when the results are used in comparisons (e.g., the log-prob in diffusion models and flow matching).
As such integrators would mainly be used for evaluation, they would not need to be differentiable, which would also allow us to use third-party integrators. I think SciPy's scipy.integrate.solve_ivp would be the most obvious choice for this. An implementation could look approximately like this (it would need more careful flattening/reshaping to work in a general setting):

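import numpy as np
import keras
import bayesflow as bf

from collections.abc import Callable
import scipy.integrate


def integrate_scipy(
    fn: Callable,
    state: dict,
    start_time: np.ndarray,
    stop_time: np.ndarray,
    method: str = "RK45",
    atol: float = 1e-6,
    rtol: float = 1e-3,
    **kwargs,
):
    # kwargs absorbs options meant for the fixed-step integrators (e.g., steps); they are ignored here.
    # Concatenate all state entries along the last axis and cast float32 -> float64 for SciPy.
    adapter = (
        bf.Adapter()
        .concatenate(list(state.keys()), into="x", axis=-1)
        .convert_dtype(np.float32, np.float64)
    )
    # The adapter is meant for NumPy data, so convert backend tensors first.
    state = keras.tree.map_structure(keras.ops.convert_to_numpy, state)
    initial_state = adapter.forward(state)["x"]
    shape = initial_state.shape

    def scipy_wrapper_fn(time, x):
        # solve_ivp works on a flat 1-D vector: restore the batched state dict for fn.
        state = adapter.inverse({"x": x.reshape(shape)})
        state = keras.tree.map_structure(keras.ops.convert_to_tensor, state)
        time = keras.ops.convert_to_tensor(time, dtype="float32")
        deltas = fn(time, **bf.utils.filter_kwargs(state, fn))
        # Convert the time derivatives back to NumPy and flatten them the same way.
        deltas = keras.tree.map_structure(keras.ops.convert_to_numpy, deltas)
        return adapter.forward(deltas)["x"].reshape(-1)

    res = scipy.integrate.solve_ivp(
        scipy_wrapper_fn,
        (float(start_time), float(stop_time)),
        initial_state.reshape(-1),
        method=method,
        atol=atol,
        rtol=rtol,
    )
    if not res.success:
        raise RuntimeError(f"solve_ivp failed: {res.message}")
    # Keep only the state at stop_time, restore the dict structure and return backend tensors.
    final_state = adapter.inverse({"x": res.y[:, -1].reshape(shape)})
    return keras.tree.map_structure(keras.ops.convert_to_tensor, final_state)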
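The concatenation above joins the state entries along axis=-1, so it assumes they share every leading axis. A minimal NumPy-only sketch of more general flattening, assuming the state is a flat dict of arrays with arbitrary shapes and that fn(t, **state) returns a dict with the same keys and shapes (a Keras/BayesFlow version would add the tensor conversions from the prototype). The helpers flatten_state, unflatten_state and integrate_scipy_numpy are hypothetical and not part of BayesFlow:

```python
from collections.abc import Callable

import numpy as np
import scipy.integrate


def flatten_state(state: dict) -> tuple[np.ndarray, list]:
    """Concatenate all arrays into one float64 vector and record how to undo it."""
    spec = [(key, np.shape(value), np.asarray(value).dtype) for key, value in state.items()]
    flat = np.concatenate(
        [np.asarray(value, dtype=np.float64).reshape(-1) for value in state.values()]
    )
    return flat, spec


def unflatten_state(flat: np.ndarray, spec: list) -> dict:
    """Split the vector back into arrays with their original shapes and dtypes."""
    state, offset = {}, 0
    for key, shape, dtype in spec:
        size = int(np.prod(shape))  # np.prod(()) == 1, so scalars work too
        state[key] = flat[offset:offset + size].reshape(shape).astype(dtype)
        offset += size
    return state


def integrate_scipy_numpy(
    fn: Callable,
    state: dict,
    start_time: float,
    stop_time: float,
    method: str = "RK45",
    atol: float = 1e-6,
    rtol: float = 1e-3,
    **solver_kwargs,
) -> dict:
    y0, spec = flatten_state(state)

    def rhs(t, y):
        deltas = fn(t, **unflatten_state(y, spec))
        # Flatten in the order recorded in spec, not the order fn returns its keys.
        return np.concatenate(
            [np.asarray(deltas[key], dtype=np.float64).reshape(-1) for key, _, _ in spec]
        )

    res = scipy.integrate.solve_ivp(
        rhs, (float(start_time), float(stop_time)), y0,
        method=method, atol=atol, rtol=rtol, **solver_kwargs,
    )
    if not res.success:
        raise RuntimeError(f"solve_ivp failed: {res.message}")
    return unflatten_state(res.y[:, -1], spec)


def decay(t, x, c):
    # Toy vector field: x decays exponentially, c grows linearly in time.
    return {"x": -x, "c": np.ones_like(c)}


state = {"x": np.ones((2, 3)), "c": np.array(0.0)}
result = integrate_scipy_numpy(decay, state, 0.0, 1.0, rtol=1e-8, atol=1e-10)
print(result["x"].shape, float(result["c"]))
```

Recording the key order once in spec keeps the flattening stable even if fn returns its derivatives in a different key order.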

Regarding the interface, we could accept "scipy" as a value for the method argument and forward any additional kwargs to scipy.integrate.solve_ivp.
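One possible shape for that interface, sketched under assumptions: a hypothetical integrate front end in which method="scipy" routes to solve_ivp, a hypothetical scipy_method keyword selects the SciPy solver, and all remaining keywords are forwarded to solve_ivp unchanged. The fixed-step branch is a plain Euler loop standing in for the existing integrators; none of these names are BayesFlow's actual API:

```python
from collections.abc import Callable

import numpy as np
import scipy.integrate


def integrate(
    fn: Callable[[float, np.ndarray], np.ndarray],
    y0: np.ndarray,
    start_time: float,
    stop_time: float,
    method: str = "euler",
    steps: int = 100,
    scipy_method: str = "RK45",
    **kwargs,
) -> np.ndarray:
    """Integrate dy/dt = fn(t, y) for a 1-D state y (hypothetical front end)."""
    y0 = np.asarray(y0, dtype=np.float64)

    if method == "scipy":
        # Adaptive path: steps is ignored; atol, rtol, max_step, etc. are
        # forwarded unchanged to solve_ivp through **kwargs.
        res = scipy.integrate.solve_ivp(
            fn, (start_time, stop_time), y0, method=scipy_method, **kwargs
        )
        if not res.success:
            raise RuntimeError(f"solve_ivp failed: {res.message}")
        return res.y[:, -1]

    if method == "euler":
        # Fixed-step path: tolerances would have no effect here, so reject
        # leftover options instead of silently ignoring them.
        if kwargs:
            raise TypeError(f"unexpected options for method='euler': {sorted(kwargs)}")
        dt = (stop_time - start_time) / steps
        y = y0.copy()
        for i in range(steps):
            y = y + dt * np.asarray(fn(start_time + i * dt, y))
        return y

    raise ValueError(f"unknown method: {method!r}")


def decay(t, y):
    return -y


y_euler = integrate(decay, np.array([1.0]), 0.0, 1.0, method="euler", steps=100)
y_scipy = integrate(decay, np.array([1.0]), 0.0, 1.0, method="scipy", rtol=1e-8, atol=1e-10)
```

Rejecting leftover keywords in the fixed-step branch is only one design choice; ignoring them (as the prototype does with steps) is more permissive but can hide a misspelled tolerance name.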

Tagging @LarsKue and @stefanradev93. Would you welcome such a change, and if so, what would be your preferred interface?

Edit: For a trained approximator using a diffusion model or flow matching, you can test it like this:

approximator.inference_network.integrate_kwargs['steps'] = 100  # unused by integrate_scipy (absorbed by **kwargs)
approximator.inference_network.integrate_kwargs['atol'] = 1e-5
approximator.inference_network.integrate_kwargs['rtol'] = 1e-5
approximator.inference_network.integrate_kwargs['method'] = "RK45"  # passed to solve_ivp as the solver name

# Monkeypatch the integrate function referenced by each network module
bf.networks.diffusion_model.diffusion_model.integrate = integrate_scipy
bf.networks.flow_matching.flow_matching.integrate = integrate_scipy
log_prob = approximator.log_prob(data=dataset)
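Before relying on the patched integrator for comparisons, one way to check that the tolerances actually control the error is to integrate an ODE with a known solution at several tolerance levels. A minimal sketch using scipy.integrate.solve_ivp directly; the toy right-hand side is an assumption chosen for illustration, not a BayesFlow vector field:

```python
import numpy as np
import scipy.integrate


def rhs(t, y):
    # Toy ODE with a known solution: y(t) = y(0) * exp(sin(t)).
    return np.cos(t) * y


def final_state(rtol: float, atol: float) -> tuple[np.ndarray, int]:
    res = scipy.integrate.solve_ivp(
        rhs, (0.0, 5.0), np.array([1.0]), method="RK45", rtol=rtol, atol=atol
    )
    if not res.success:
        raise RuntimeError(res.message)
    return res.y[:, -1], res.nfev  # final state and number of rhs evaluations


exact = np.exp(np.sin(5.0))
results = {}
for tol in (1e-3, 1e-6, 1e-9):
    y_end, nfev = final_state(rtol=tol, atol=tol)
    results[tol] = (abs(y_end[0] - exact), nfev)
    print(f"tol={tol:.0e}  error={results[tol][0]:.2e}  evaluations={nfev}")
```

For a trained network no exact answer exists, but the same idea applies: if log_prob barely changes when rtol and atol are tightened by an order of magnitude, the remaining integration error is unlikely to affect the comparison.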

Labels: discussion, feature