How to estimate intractable likelihood functions with normalizing flows #791

@itsdfish

Description


I am opening this issue based on this Discourse thread, in which I asked about estimating intractable likelihood functions with normalizing flows. As described in this paper, normalizing flows can approximate the likelihood function of a model by learning the relationship between the model's parameters and its output. A working example would be of broad interest, given that many scientific fields work with complex models whose likelihood functions are unknown.

The paper is associated with a Python package called sbi. Here is a simple working example based on a LogNormal distribution; I pasted my attempt at replicating it in Julia below. Note that the package generalizes to processes that emit samples from multiple distributions, but I have used a single distribution for simplicity.

The architectural details in the article were a little sparse; the relevant passage reads:

For the neural spline flow architecture (Durkan et al., 2019), we transformed the reaction time data to the log-domain, used a standard normal base distribution, 2 spline transforms with 5 bins each and conditioning networks with 3 hidden layers and 10 hidden units each, and rectified linear unit activation functions. The neural network training was performed using the sbi package with the following settings: learning rate 0.0005; training batch size 100; 10% of training data as validation data, stop training after 20 epochs without validation loss improvement.
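
For reference, the quoted conditioning network is easy to express in Flux. Here is a minimal sketch, assuming rational-quadratic splines as in Durkan et al. (2019), where each 5-bin transform needs 3 × 5 - 1 = 14 parameters per data dimension (bin widths, bin heights, and interior knot derivatives); the spline transform itself is the piece that would still need a Julia implementation:

using Flux

n_cond = 2                        # conditioning inputs: μ and σ
n_bins = 5
n_spline_parms = 3 * n_bins - 1   # widths + heights + interior derivatives
# 3 hidden layers of 10 ReLU units each, per the quoted settings
conditioner = Flux.Chain(
    Flux.Dense(n_cond, 10, relu),
    Flux.Dense(10, 10, relu),
    Flux.Dense(10, 10, relu),
    Flux.Dense(10, n_spline_parms),
)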

It appears that the architecture is heavily influenced by Sequential Neural Likelihood: Fast Likelihood-free Inference with Autoregressive Flows. What appears to be the core Python code can be found here.
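
For context, SNL alternates between simulating from the model and refitting the conditional density estimator, using MCMC under the current likelihood estimate to propose the next round's parameters. Below is a rough sketch of that loop, where `simulate`, `train_flow`, `flow_logpdf`, and `mcmc_sample` are hypothetical placeholders rather than functions from any existing package:

# hypothetical placeholders: simulate, train_flow, flow_logpdf, mcmc_sample
θs, xs = [], []
proposal = [rand(prior, 2) for _ in 1:n_per_round]
for r in 1:n_rounds
    append!(θs, proposal)
    append!(xs, simulate.(proposal))
    flow = train_flow(θs, xs)  # fit the conditional density q(x | θ)
    # next round: draw parameters from the approximate posterior
    # q(x_obs | θ) p(θ) via MCMC
    proposal = mcmc_sample(θ -> flow_logpdf(flow, x_obs, θ) + sum(logpdf.(prior, θ)), n_per_round)
end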

For additional background, Hossein has started a related package, but it is experimental and currently undocumented.

Thank you for looking into this. I don't know much about neural networks or Flux, but let me know if I can help in any way.

WIP Code


###########################################################################################################
#                                           load packages
###########################################################################################################
cd(@__DIR__)
using Pkg 
Pkg.activate("")
using Flux, DiffEqFlux, DifferentialEquations, Optimization, OptimizationFlux
using OptimizationOptimJL, Distributions
using Random, Statistics # Statistics provides `mean`, used in the loss below
Random.seed!(3411)
###########################################################################################################
#                                           generate training data
###########################################################################################################
n_parms = 10_000
# training parameters
train_parms = map(x -> Float32.(rand(Gamma(1.0, .5), 2)), 1:n_parms)
# training samples
samples = map(p -> Float32(rand(LogNormal(p...))), train_parms)
training_data = [hcat(train_parms...); samples']
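# quick shape check: each column is one training case with rows (μ, σ, x)
@assert size(training_data) == (3, n_parms)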
###########################################################################################################
#                                           setup network
###########################################################################################################
nn = Flux.Chain(
    # inputs are parameters μ and σ, and a distribution sample
    Flux.Dense(3, 10, tanh),
    Flux.Dense(10, 10, tanh),
    # FFJORD's dynamics map R³ → R³, so the output width must equal the
    # input width; a linear final layer leaves the dynamics unconstrained
    Flux.Dense(10, 3),
) |> f32
tspan = (0.0f0, 50.0f0)

ffjord_mdl = DiffEqFlux.FFJORD(nn, tspan, Tsit5())

function loss(θ)
    # FFJORD returns the per-case log-densities plus two regularization terms
    logpx, λ₁, λ₂ = ffjord_mdl(training_data, θ)
    -mean(logpx)
end

function cb(p, l)::Bool
    # the current loss is passed in as `l`; no need to recompute it
    @info "Training" loss = l
    false
end

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, ffjord_mdl.p)

res1 = Optimization.solve(optprob,
                          ADAM(0.1),
                          maxiters = 100, callback=cb)

optprob2 = Optimization.OptimizationProblem(optf, res1.u)
res2 = Optimization.solve(optprob2,
                          Optim.LBFGS(),
                          allow_f_increases=false, callback=cb)
###########################################################################################################
#                                           evaluate and plot
###########################################################################################################
using Plots

test_parms = Float32[1.5, 0.5]
xs = [0.05f0:0.05f0:20.0f0;]
true_density = pdf.(LogNormal(test_parms...), xs)
# is there a better way to get the estimated density?
# the flow models the joint over (μ, σ, x), so evaluate it on 3×1 inputs,
# matching the (features × cases) layout used during training
est_density = map(xs) do x
    logpx, _, _ = ffjord_mdl(reshape(vcat(test_parms, x), 3, 1), res2.u; monte_carlo = false)
    exp(first(logpx))
end
# plot the true and estimated densities
plot(xs, true_density)
plot!(xs, est_density)
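
# Note: the flow was fit to the joint density p(μ, σ, x) with (μ, σ) drawn
# from the Gamma(1, 0.5) prior, so (under that assumption) the conditional
# likelihood estimate is the joint divided by the prior density at test_parms:
prior_density = prod(pdf.(Gamma(1.0, 0.5), test_parms))
plot!(xs, est_density ./ prior_density)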
